summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/stale.yml20
-rw-r--r--.travis.yml2
-rw-r--r--Makefile14
-rw-r--r--README.md1
-rw-r--r--benchmarks/workloads/absl/Dockerfile5
-rw-r--r--benchmarks/workloads/ruby_template/Gemfile.lock2
-rw-r--r--benchmarks/workloads/tensorflow/Dockerfile2
-rw-r--r--images/packetdrill/Dockerfile2
-rw-r--r--pkg/abi/linux/tcp.go1
-rw-r--r--pkg/flipcall/BUILD1
-rw-r--r--pkg/flipcall/flipcall.go2
-rw-r--r--pkg/flipcall/packet_window_mmap.go25
-rw-r--r--pkg/merkletree/BUILD16
-rw-r--r--pkg/merkletree/merkletree.go135
-rw-r--r--pkg/merkletree/merkletree_test.go122
-rw-r--r--pkg/p9/server.go4
-rw-r--r--pkg/sentry/arch/arch_aarch64.go10
-rw-r--r--pkg/sentry/control/BUILD4
-rw-r--r--pkg/sentry/control/proc.go84
-rw-r--r--pkg/sentry/devices/memdev/full.go1
-rw-r--r--pkg/sentry/devices/memdev/null.go1
-rw-r--r--pkg/sentry/devices/memdev/random.go1
-rw-r--r--pkg/sentry/devices/memdev/zero.go1
-rw-r--r--pkg/sentry/fs/file.go11
-rw-r--r--pkg/sentry/fs/fs.go3
-rw-r--r--pkg/sentry/fs/gofer/inode.go2
-rw-r--r--pkg/sentry/fs/host/inode.go3
-rw-r--r--pkg/sentry/fs/lock/lock.go41
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go8
-rw-r--r--pkg/sentry/fs/lock/lock_test.go111
-rw-r--r--pkg/sentry/fs/user/BUILD1
-rw-r--r--pkg/sentry/fs/user/path.go57
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD1
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go5
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go5
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go5
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go1
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD1
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go9
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go29
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go11
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go22
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go1
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go29
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go18
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go13
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD2
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go9
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go23
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go3
-rw-r--r--pkg/sentry/fsimpl/host/BUILD1
-rw-r--r--pkg/sentry/fsimpl/host/host.go8
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go9
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go13
-rw-r--r--pkg/sentry/fsimpl/pipefs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go6
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go5
-rw-r--r--pkg/sentry/fsimpl/proc/task.go5
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go7
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go10
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go5
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go1
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD1
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go7
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go26
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go27
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go112
-rw-r--r--pkg/sentry/kernel/auth/credentials.go28
-rw-r--r--pkg/sentry/kernel/fd_table.go15
-rw-r--r--pkg/sentry/kernel/kernel.go5
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go13
-rw-r--r--pkg/sentry/kernel/task_exec.go4
-rw-r--r--pkg/sentry/kernel/timekeeper.go2
-rw-r--r--pkg/sentry/loader/vdso.go2
-rw-r--r--pkg/sentry/pgalloc/BUILD22
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go216
-rw-r--r--pkg/sentry/pgalloc/pgalloc_test.go198
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/address_space.go73
-rw-r--r--pkg/sentry/platform/kvm/bluepill_allocator.go (renamed from pkg/sentry/platform/kvm/allocator.go)52
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go13
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go48
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go3
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go10
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go7
-rw-r--r--pkg/sentry/platform/kvm/machine.go52
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go34
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s14
-rw-r--r--pkg/sentry/platform/ring0/kernel.go24
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator.go11
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go8
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go36
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go36
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go36
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go8
-rw-r--r--pkg/sentry/socket/netstack/BUILD2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go44
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go3
-rw-r--r--pkg/sentry/socket/netstack/stack.go9
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go2
-rw-r--r--pkg/sentry/socket/unix/unix.go39
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go39
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/lock.go64
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go145
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go6
-rw-r--r--pkg/sentry/time/muldiv_arm64.s3
-rw-r--r--pkg/sentry/time/parameters_test.go15
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/README.md4
-rw-r--r--pkg/sentry/vfs/epoll.go1
-rw-r--r--pkg/sentry/vfs/file_description.go25
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go80
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go1
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go3
-rw-r--r--pkg/sentry/vfs/inotify.go1
-rw-r--r--pkg/sentry/vfs/mount.go50
-rw-r--r--pkg/sentry/vfs/options.go4
-rw-r--r--pkg/sentry/vfs/vfs.go2
-rw-r--r--pkg/sentry/watchdog/watchdog.go10
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go7
-rw-r--r--pkg/tcpip/link/channel/channel.go8
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go10
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go2
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/loopback/loopback.go6
-rw-r--r--pkg/tcpip/link/muxed/injectable.go4
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go4
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go26
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go34
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go16
-rw-r--r--pkg/tcpip/network/arp/arp.go27
-rw-r--r--pkg/tcpip/network/arp/arp_test.go2
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go4
-rw-r--r--pkg/tcpip/network/ip_test.go66
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go15
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go115
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go40
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go15
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go57
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go126
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go8
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go43
-rw-r--r--pkg/tcpip/ports/ports.go116
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/conntrack.go50
-rw-r--r--pkg/tcpip/stack/forwarder.go4
-rw-r--r--pkg/tcpip/stack/forwarder_test.go112
-rw-r--r--pkg/tcpip/stack/iptables.go44
-rw-r--r--pkg/tcpip/stack/iptables_targets.go5
-rw-r--r--pkg/tcpip/stack/iptables_types.go17
-rw-r--r--pkg/tcpip/stack/ndp.go100
-rw-r--r--pkg/tcpip/stack/ndp_test.go14
-rw-r--r--pkg/tcpip/stack/nic.go104
-rw-r--r--pkg/tcpip/stack/nic_test.go259
-rw-r--r--pkg/tcpip/stack/packet_buffer.go33
-rw-r--r--pkg/tcpip/stack/registration.go39
-rw-r--r--pkg/tcpip/stack/route.go8
-rw-r--r--pkg/tcpip/stack/stack.go46
-rw-r--r--pkg/tcpip/stack/stack_test.go94
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go109
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go7
-rw-r--r--pkg/tcpip/stack/transport_test.go48
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go21
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go12
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go16
-rw-r--r--pkg/tcpip/transport/tcp/BUILD4
-rw-r--r--pkg/tcpip/transport/tcp/accept.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go65
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go51
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go25
-rw-r--r--pkg/tcpip/transport/tcp/segment.go37
-rw-r--r--pkg/tcpip/transport/tcp/snd.go50
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go10
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go92
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go7
-rw-r--r--pkg/tcpip/transport/udp/protocol.go51
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go64
-rw-r--r--pkg/tmutex/BUILD17
-rw-r--r--pkg/tmutex/tmutex.go81
-rw-r--r--pkg/tmutex/tmutex_test.go258
-rw-r--r--runsc/boot/filter/extra_filters_msan.go2
-rw-r--r--runsc/boot/fs.go6
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/vfs.go8
-rw-r--r--runsc/cgroup/BUILD4
-rw-r--r--runsc/cgroup/cgroup.go43
-rw-r--r--runsc/cgroup/cgroup_test.go582
-rw-r--r--runsc/cmd/gofer.go11
-rw-r--r--runsc/container/container_test.go272
-rw-r--r--runsc/container/multi_container_test.go45
-rwxr-xr-xscripts/common_build.sh4
-rwxr-xr-xscripts/packetdrill_tests.sh3
-rwxr-xr-xscripts/packetimpact_tests.sh3
-rw-r--r--test/e2e/integration_test.go4
-rw-r--r--test/iptables/iptables_test.go4
-rw-r--r--test/iptables/nat.go89
-rw-r--r--test/packetimpact/testbench/connections.go263
-rw-r--r--test/packetimpact/testbench/layers.go5
-rw-r--r--test/packetimpact/tests/BUILD26
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go20
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go30
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go22
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go36
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go84
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go12
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go28
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go22
-rw-r--r--test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go18
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go22
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go105
-rw-r--r--test/packetimpact/tests/tcp_should_piggyback_test.go64
-rw-r--r--test/packetimpact/tests/tcp_splitseg_mss_test.go71
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go18
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go22
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go24
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go33
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go38
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go30
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go69
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go12
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go17
-rw-r--r--test/runner/defs.bzl1
-rw-r--r--test/runtimes/proctor/BUILD2
-rw-r--r--test/syscalls/BUILD16
-rw-r--r--test/syscalls/linux/BUILD43
-rw-r--r--test/syscalls/linux/accept_bind.cc42
-rw-r--r--test/syscalls/linux/exec.cc27
-rw-r--r--test/syscalls/linux/exec_binary.cc5
-rw-r--r--test/syscalls/linux/flock.cc75
-rw-r--r--test/syscalls/linux/fpsig_fork.cc6
-rw-r--r--test/syscalls/linux/fpsig_nested.cc49
-rw-r--r--test/syscalls/linux/open.cc23
-rw-r--r--test/syscalls/linux/ping_socket.cc91
-rw-r--r--test/syscalls/linux/poll.cc6
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc29
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc9
-rw-r--r--test/syscalls/linux/socket_inet_loopback_nogotsan.cc171
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc73
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc6
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc63
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc56
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket.cc18
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc14
-rw-r--r--test/syscalls/linux/tuntap.cc20
-rw-r--r--tools/defs.bzl1
-rw-r--r--tools/go_generics/globals/scope.go4
287 files changed, 5808 insertions, 2735 deletions
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
new file mode 100644
index 000000000..0b31fecf5
--- /dev/null
+++ b/.github/workflows/stale.yml
@@ -0,0 +1,20 @@
+name: "Close stale issues"
+on:
+ schedule:
+ - cron: "0 0 * * *"
+
+jobs:
+ stale:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/stale@v3
+ with:
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
+ stale-issue-label: 'stale'
+ stale-pr-label: 'stale'
+ exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question'
+ exempt-pr-labels: 'ready to pull'
+ stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ days-before-stale: 90
+ days-before-close: 30
diff --git a/.travis.yml b/.travis.yml
index fbc0e46d7..4cb202b7d 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -16,7 +16,7 @@ script:
# On arm64, we need to create our own pipes for stderr and stdout,
# otherwise we will not be able to open /dev/stderr. This is probably
# due to AppArmor rules.
- - uname -a && make smoke-test 2>&1 | cat
+ - bash -xeo pipefail -c 'uname -a && make smoke-test 2>&1 | cat'
branches:
except:
# Skip copybara branches.
diff --git a/Makefile b/Makefile
index d6cd9ec03..85818ebea 100644
--- a/Makefile
+++ b/Makefile
@@ -171,12 +171,20 @@ RELEASE_COMMIT :=
RELEASE_NAME :=
RELEASE_NOTES :=
+GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi)
$(RELEASE_KEY):
@echo "WARNING: Generating a key for testing ($@); don't use this."
T=$$(mktemp /tmp/keyring.XXXXXX); \
- gpg --no-default-keyring --keyring $$T --batch --passphrase "" --quick-generate-key $(shell whoami) && \
- gpg --export-secret-keys --no-default-keyring --keyring $$T > $@; \
- rc=$$?; rm -f $$T; exit $$rc
+ C=$$(mktemp /tmp/config.XXXXXX); \
+ echo Key-Type: DSA >> $$C && \
+ echo Key-Length: 1024 >> $$C && \
+ echo Name-Real: Test >> $$C && \
+ echo Name-Email: test@example.com >> $$C && \
+ echo Expire-Date: 0 >> $$C && \
+ echo %commit >> $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \
+ rc=$$?; rm -f $$T $$C; exit $$rc
release: $(RELEASE_KEY) ## Builds a release.
@mkdir -p $(RELEASE_ROOT)
diff --git a/README.md b/README.md
index ce3947907..d72d1dac4 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
![gVisor](g3doc/logo.png)
![](https://github.com/google/gvisor/workflows/Build/badge.svg)
+[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community)
## What is gVisor?
diff --git a/benchmarks/workloads/absl/Dockerfile b/benchmarks/workloads/absl/Dockerfile
index e935c5ddc..f29cfa156 100644
--- a/benchmarks/workloads/absl/Dockerfile
+++ b/benchmarks/workloads/absl/Dockerfile
@@ -16,9 +16,10 @@ RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27
RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh
RUN ./bazel-0.27.0-installer-linux-x86_64.sh
-RUN git clone https://github.com/abseil/abseil-cpp.git
+RUN mkdir abseil-cpp && cd abseil-cpp \
+ && git init && git remote add origin https://github.com/abseil/abseil-cpp.git \
+ && git fetch --depth 1 origin 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f && git checkout FETCH_HEAD
WORKDIR abseil-cpp
-RUN git checkout 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f
RUN bazel clean
ENV path "absl/base/..."
CMD bazel build ${path} 2>&1
diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock
index f637b6081..eeb3c7bbe 100644
--- a/benchmarks/workloads/ruby_template/Gemfile.lock
+++ b/benchmarks/workloads/ruby_template/Gemfile.lock
@@ -2,7 +2,7 @@ GEM
remote: https://rubygems.org/
specs:
mustermann (1.0.3)
- puma (3.12.4)
+ puma (3.12.6)
rack (2.0.6)
rack-protection (2.0.5)
rack
diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile
index 262643b98..b5763e8ae 100644
--- a/benchmarks/workloads/tensorflow/Dockerfile
+++ b/benchmarks/workloads/tensorflow/Dockerfile
@@ -2,7 +2,7 @@ FROM tensorflow/tensorflow:1.13.2
RUN apt-get update \
&& apt-get install -y git
-RUN git clone https://github.com/aymericdamien/TensorFlow-Examples.git
+RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git
RUN python -m pip install -U pip setuptools
RUN python -m pip install matplotlib
diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile
index 7a006c85f..01296dbaf 100644
--- a/images/packetdrill/Dockerfile
+++ b/images/packetdrill/Dockerfile
@@ -2,7 +2,7 @@ FROM ubuntu:bionic
RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
netcat tcpdump jq tar bison flex make
RUN hash -r
-RUN git clone --branch packetdrill-v2.0 \
+RUN git clone --depth 1 --branch packetdrill-v2.0 \
https://github.com/google/packetdrill.git
RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
CMD /bin/bash
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
index 174d470e2..2a8d4708b 100644
--- a/pkg/abi/linux/tcp.go
+++ b/pkg/abi/linux/tcp.go
@@ -57,4 +57,5 @@ const (
const (
MAX_TCP_KEEPIDLE = 32767
MAX_TCP_KEEPINTVL = 32767
+ MAX_TCP_KEEPCNT = 127
)
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 9c5ad500b..aa8e4e1f3 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,6 +11,7 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
+ "packet_window_mmap.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 3cdb576e1..ec742c091 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -95,7 +95,7 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...
if pwd.Length > math.MaxUint32 {
return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
}
- m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ m, e := packetWindowMmap(pwd)
if e != 0 {
return fmt.Errorf("failed to mmap packet window: %v", e)
}
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap.go
new file mode 100644
index 000000000..869183b11
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
new file mode 100644
index 000000000..5b0e4143a
--- /dev/null
+++ b/pkg/merkletree/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "merkletree",
+ srcs = ["merkletree.go"],
+ deps = ["//pkg/usermem"],
+)
+
+go_test(
+ name = "merkletree_test",
+ srcs = ["merkletree_test.go"],
+ library = ":merkletree",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
new file mode 100644
index 000000000..906f67943
--- /dev/null
+++ b/pkg/merkletree/merkletree.go
@@ -0,0 +1,135 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package merkletree implements Merkle tree generating and verification.
+package merkletree
+
+import (
+ "crypto/sha256"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // sha256DigestSize specifies the digest size of a SHA256 hash.
+ sha256DigestSize = 32
+)
+
+// Size defines the scale of a Merkle tree.
+type Size struct {
+ // blockSize is the size of a data block to be hashed.
+ blockSize int64
+ // digestSize is the size of a generated hash.
+ digestSize int64
+ // hashesPerBlock is the number of hashes in a block. For example, if
+ // blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128
+ // hashesPerBlock. Therefore 128 hashes in a lower level will be put into a
+ // block and generate a single hash in an upper level.
+ hashesPerBlock int64
+ // levelStart is the start block index of each level. The number of levels in
+ // the tree is the length of the slice. The leafs (level 0) are hashes of
+ // blocks in the input data. The levels above are hashes of lower level
+ // hashes. The highest level is the root hash.
+ levelStart []int64
+}
+
+// MakeSize initializes and returns a new Size object describing the structure
+// of a tree. dataSize specifies the number of the file system size in bytes.
+func MakeSize(dataSize int64) Size {
+ size := Size{
+ blockSize: usermem.PageSize,
+ // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
+ digestSize: sha256DigestSize,
+ hashesPerBlock: usermem.PageSize / sha256DigestSize,
+ }
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+ level := int64(0)
+ offset := int64(0)
+
+ // Calcuate the number of levels in the Merkle tree and the beginning offset
+ // of each level. Level 0 is the level directly above the data blocks, while
+ // level NumLevels - 1 is the root.
+ for numBlocks > 1 {
+ size.levelStart = append(size.levelStart, offset)
+ // Round numBlocks up to fill up a block.
+ numBlocks += (size.hashesPerBlock - numBlocks%size.hashesPerBlock) % size.hashesPerBlock
+ offset += numBlocks / size.hashesPerBlock
+ numBlocks = numBlocks / size.hashesPerBlock
+ level++
+ }
+ size.levelStart = append(size.levelStart, offset)
+ return size
+}
+
+// Generate constructs a Merkle tree for the contents of data. The output is
+// written to treeWriter. The treeReader should be able to read the tree after
+// it has been written. That is, treeWriter and treeReader should point to the
+// same underlying data but have separate cursors.
+func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) {
+ size := MakeSize(dataSize)
+
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+
+ var root []byte
+ for level := 0; level < len(size.levelStart); level++ {
+ for i := int64(0); i < numBlocks; i++ {
+ buf := make([]byte, size.blockSize)
+ var (
+ n int
+ err error
+ )
+ if level == 0 {
+ // Read data block from the target file since level 0 is directly above
+ // the raw data block.
+ n, err = data.Read(buf)
+ } else {
+ // Read data block from the tree file since levels higher than 0 are
+ // hashing the lower level hashes.
+ n, err = treeReader.Read(buf)
+ }
+
+ // err is populated as long as the bytes read is smaller than the buffer
+ // size. This could be the case if we are reading the last block, and
+ // break in that case. If this is the last block, the end of the block
+ // will be zero-padded.
+ if n == 0 && err == io.EOF {
+ break
+ } else if err != nil && err != io.EOF {
+ return nil, err
+ }
+ // Hash the bytes in buf.
+ digest := sha256.Sum256(buf)
+
+ if level == len(size.levelStart)-1 {
+ root = digest[:]
+ }
+
+ // Write the generated hash to the end of the tree file.
+ if _, err = treeWriter.Write(digest[:]); err != nil {
+ return nil, err
+ }
+ }
+ // If the genereated digests do not round up to a block, zero-padding the
+ // remaining of the last block. But no need to do so for root.
+ if level != len(size.levelStart)-1 && numBlocks%size.hashesPerBlock != 0 {
+ zeroBuf := make([]byte, size.blockSize-(numBlocks%size.hashesPerBlock)*size.digestSize)
+ if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ return nil, err
+ }
+ }
+ numBlocks = (numBlocks + size.hashesPerBlock - 1) / size.hashesPerBlock
+ }
+ return root, nil
+}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
new file mode 100644
index 000000000..7344db0b6
--- /dev/null
+++ b/pkg/merkletree/merkletree_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package merkletree
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSize(t *testing.T) {
+ testCases := []struct {
+ dataSize int64
+ expectedLevelStart []int64
+ }{
+ {
+ dataSize: 100,
+ expectedLevelStart: []int64{0},
+ },
+ {
+ dataSize: 1000000,
+ expectedLevelStart: []int64{0, 2, 3},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ expectedLevelStart: []int64{0, 32, 33},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ s := MakeSize(tc.dataSize)
+ if s.blockSize != int64(usermem.PageSize) {
+ t.Errorf("got blockSize %d, want %d", s.blockSize, usermem.PageSize)
+ }
+ if s.digestSize != sha256DigestSize {
+ t.Errorf("got digestSize %d, want %d", s.digestSize, sha256DigestSize)
+ }
+ if len(s.levelStart) != len(tc.expectedLevelStart) {
+ t.Errorf("got levels %d, want %d", len(s.levelStart), len(tc.expectedLevelStart))
+ }
+ for i := 0; i < len(s.levelStart) && i < len(tc.expectedLevelStart); i++ {
+ if s.levelStart[i] != tc.expectedLevelStart[i] {
+ t.Errorf("got levelStart[%d] %d, want %d", i, s.levelStart[i], tc.expectedLevelStart[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGenerate(t *testing.T) {
+ // The input data has size dataSize. It starts with the data in startWith,
+ // and all other bytes are zeroes.
+ testCases := []struct {
+ dataSize int
+ startWith []byte
+ expectedRoot []byte
+ }{
+ {
+ dataSize: usermem.PageSize,
+ startWith: nil,
+ expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ },
+ {
+ dataSize: 128*usermem.PageSize + 1,
+ startWith: nil,
+ expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'a'},
+ expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'1'},
+ expectedRoot: []byte{74, 35, 103, 179, 176, 149, 254, 112, 42, 65, 104, 66, 119, 56, 133, 124, 228, 15, 65, 161, 150, 0, 117, 174, 242, 34, 115, 115, 218, 37, 3, 105},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ var (
+ data bytes.Buffer
+ tree bytes.Buffer
+ )
+
+ startSize := len(tc.startWith)
+ _, err := data.Write(tc.startWith)
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+ _, err = data.Write(make([]byte, tc.dataSize-startSize))
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+
+ root, err := Generate(&data, int64(tc.dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ if !bytes.Equal(root, tc.expectedRoot) {
+ t.Errorf("Unexpected root")
+ }
+ })
+ }
+}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index fdfa83648..60cf94fa1 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -482,10 +482,10 @@ func (cs *connState) handle(m message) (r message) {
defer func() {
if r == nil {
// Don't allow a panic to propagate.
- recover()
+ err := recover()
// Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ log.Warningf("panic in handler: %v\n%s", err, debug.Stack())
// Wrap in an EFAULT error; we don't really have a
// better way to describe this kind of error. It will
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 343f81f59..daba8b172 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -17,7 +17,6 @@
package arch
import (
- "encoding/binary"
"fmt"
"io"
@@ -49,9 +48,14 @@ const ARMTrapFlag = uint64(1) << 21
type aarch64FPState []byte
// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is only a space of 0x210 length for fpstate.
+// The fp head is useless in sentry/ptrace/kvm.
+//
func initAarch64FPState(data aarch64FPState) {
- binary.LittleEndian.PutUint32(data, fpsimdMagic)
- binary.LittleEndian.PutUint32(data[4:], fpsimdContextSize)
}
func newAarch64FPStateSlice() []byte {
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 0c9a62f0d..2c5d14be5 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -16,15 +16,12 @@ go_library(
],
deps = [
"//pkg/abi/linux",
- "//pkg/context",
"//pkg/fd",
- "//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fdimport",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
"//pkg/sentry/fs/user",
- "//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/host",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -36,7 +33,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 8767430b7..1bae7cfaf 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -25,13 +25,10 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fdimport"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/fs/user"
- "gvisor.dev/gvisor/pkg/sentry/fsbridge"
hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -39,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -107,6 +103,9 @@ type ExecArgs struct {
// String prints the arguments as a string.
func (args ExecArgs) String() string {
+ if len(args.Argv) == 0 {
+ return args.Filename
+ }
a := make([]string, len(args.Argv))
copy(a, args.Argv)
if args.Filename != "" {
@@ -179,37 +178,30 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
}
ctx := initArgs.NewContext(proc.Kernel)
- if initArgs.Filename == "" {
- if kernel.VFS2Enabled {
- // Get the full path to the filename from the PATH env variable.
- if initArgs.MountNamespaceVFS2 == nil {
- // Set initArgs so that 'ctx' returns the namespace.
- //
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
- initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
- }
- file, err := getExecutableFD(ctx, creds, proc.Kernel.VFS(), initArgs.MountNamespaceVFS2, initArgs.Envv, initArgs.WorkingDirectory, initArgs.Argv[0])
- if err != nil {
- return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in environment %v: %v", initArgs.Argv[0], initArgs.Envv, err)
- }
- initArgs.File = fsbridge.NewVFSFile(file)
- } else {
- if initArgs.MountNamespace == nil {
- // Set initArgs so that 'ctx' returns the namespace.
- initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
+ if kernel.VFS2Enabled {
+ // Get the full path to the filename from the PATH env variable.
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ }
+ } else {
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
- // initArgs must hold a reference on MountNamespace, which will
- // be donated to the new process in CreateProcess.
- initArgs.MountNamespace.IncRef()
- }
- f, err := user.ResolveExecutablePath(ctx, creds, initArgs.MountNamespace, initArgs.Envv, initArgs.WorkingDirectory, initArgs.Argv[0])
- if err != nil {
- return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], initArgs.Envv, err)
- }
- initArgs.Filename = f
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespace.IncRef()
}
}
+ resolved, err := user.ResolveExecutablePath(ctx, &initArgs)
+ if err != nil {
+ return nil, 0, nil, nil, err
+ }
+ initArgs.Filename = resolved
fds := make([]int, len(args.FilePayload.Files))
for i, file := range args.FilePayload.Files {
@@ -422,31 +414,3 @@ func ttyName(tty *kernel.TTY) string {
}
return fmt.Sprintf("pts/%d", tty.Index)
}
-
-// getExecutableFD resolves the given executable name and returns a
-// vfs.FileDescription for the executable file.
-func getExecutableFD(ctx context.Context, creds *auth.Credentials, vfsObj *vfs.VirtualFilesystem, mns *vfs.MountNamespace, envv []string, wd, name string) (*vfs.FileDescription, error) {
- path, err := user.ResolveExecutablePathVFS2(ctx, creds, mns, envv, wd, name)
- if err != nil {
- return nil, err
- }
-
- root := vfs.RootFromContext(ctx)
- defer root.DecRef()
-
- pop := vfs.PathOperation{
- Root: root,
- Start: root, // binPath is absolute, Start can be anything.
- Path: fspath.Parse(path),
- FollowFinalSymlink: true,
- }
- opts := &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- FileExec: true,
- }
- f, err := vfsObj.OpenAt(ctx, creds, &pop, opts)
- if err == syserror.ENOENT || err == syserror.EACCES {
- return nil, nil
- }
- return f, err
-}
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
index c7e197691..af66fe4dc 100644
--- a/pkg/sentry/devices/memdev/full.go
+++ b/pkg/sentry/devices/memdev/full.go
@@ -42,6 +42,7 @@ type fullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
index 33d060d02..92d3d71be 100644
--- a/pkg/sentry/devices/memdev/null.go
+++ b/pkg/sentry/devices/memdev/null.go
@@ -43,6 +43,7 @@ type nullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
index acfa23149..6b81da5ef 100644
--- a/pkg/sentry/devices/memdev/random.go
+++ b/pkg/sentry/devices/memdev/random.go
@@ -48,6 +48,7 @@ type randomFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// off is the "file offset". off is accessed using atomic memory
// operations.
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 3b1372b9e..c6f15054d 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -44,6 +44,7 @@ type zeroFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 846252c89..ca41520b4 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -146,7 +146,7 @@ func (f *File) DecRef() {
f.DecRefWithDestructor(func() {
// Drop BSD style locks.
lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
- f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng)
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
// Release resources held by the FileOperations.
f.FileOperations.Release()
@@ -310,7 +310,6 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error
if !f.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
-
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
// Handle append mode.
if f.Flags().Append {
@@ -355,7 +354,6 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
// offset."
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
defer unlockAppendMu()
-
if f.Flags().Append {
if err := f.offsetForAppend(ctx, &offset); err != nil {
return 0, err
@@ -374,9 +372,10 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
return f.FileOperations.Write(ctx, f, src, offset)
}
-// offsetForAppend sets the given offset to the end of the file.
+// offsetForAppend atomically sets the given offset to the end of the file.
//
-// Precondition: the file.Dirent.Inode.appendMu mutex should be held for writing.
+// Precondition: the file.Dirent.Inode.appendMu mutex should be held for
+// writing.
func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
uattr, err := f.Dirent.Inode.UnstableAttr(ctx)
if err != nil {
@@ -386,7 +385,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
}
// Update the offset.
- *offset = uattr.Size
+ atomic.StoreInt64(offset, uattr.Size)
return nil
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index bdba6efe5..d2dbff268 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -42,9 +42,10 @@
// Dirent.dirMu
// Dirent.mu
// DirentCache.mu
-// Locks in InodeOperations implementations or overlayEntry
// Inode.Watches.mu (see `Inotify` for other lock ordering)
// MountSource.mu
+// Inode.appendMu
+// Locks in InodeOperations implementations or overlayEntry
//
// If multiple Dirent or MountSource locks must be taken, locks in the parent must be
// taken before locks in their children.
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index a016c896e..51d7368a1 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -640,7 +640,7 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
- if !i.session().cachePolicy.cacheUAttrs(inode) {
+ if inode.MountSource.Flags.ReadOnly || !i.session().cachePolicy.cacheUAttrs(inode) {
return nil
}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 62f1246aa..fbfba1b58 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -368,6 +368,9 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if inode.MountSource.Flags.ReadOnly {
+ return nil
+ }
// Have we been using host kernel metadata caches?
if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
// Then the metadata is already up to date on the host.
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 926538d90..8a5d9c7eb 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -62,7 +62,7 @@ import (
type LockType int
// UniqueID is a unique identifier of the holder of a regional file lock.
-type UniqueID uint64
+type UniqueID interface{}
const (
// ReadLock describes a POSIX regional file lock to be taken
@@ -98,12 +98,7 @@ type Lock struct {
// If len(Readers) > 0 then HasWriter must be false.
Readers map[UniqueID]bool
- // HasWriter indicates that this is a write lock held by a single
- // UniqueID.
- HasWriter bool
-
- // Writer is only valid if HasWriter is true. It identifies a
- // single write lock holder.
+ // Writer holds the writer unique ID. It's nil if there are no writers.
Writer UniqueID
}
@@ -186,7 +181,6 @@ func makeLock(uid UniqueID, t LockType) Lock {
case ReadLock:
value.Readers[uid] = true
case WriteLock:
- value.HasWriter = true
value.Writer = uid
default:
panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
@@ -196,10 +190,7 @@ func makeLock(uid UniqueID, t LockType) Lock {
// isHeld returns true if uid is a holder of Lock.
func (l Lock) isHeld(uid UniqueID) bool {
- if l.HasWriter && l.Writer == uid {
- return true
- }
- return l.Readers[uid]
+ return l.Writer == uid || l.Readers[uid]
}
// lock sets uid as a holder of a typed lock on Lock.
@@ -214,20 +205,20 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// We cannot downgrade a write lock to a read lock unless the
// uid is the same.
- if l.HasWriter {
+ if l.Writer != nil {
if l.Writer != uid {
panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
}
// Ensure that there is only one reader if upgrading.
l.Readers = make(map[UniqueID]bool)
// Ensure that there is no longer a writer.
- l.HasWriter = false
+ l.Writer = nil
}
l.Readers[uid] = true
return
case WriteLock:
// If we are already the writer, then this is a no-op.
- if l.HasWriter && l.Writer == uid {
+ if l.Writer == uid {
return
}
// We can only upgrade a read lock to a write lock if there
@@ -243,7 +234,6 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// Ensure that there is only a writer.
l.Readers = make(map[UniqueID]bool)
- l.HasWriter = true
l.Writer = uid
default:
panic(fmt.Sprintf("lock: invalid lock type %d", t))
@@ -277,9 +267,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
switch t {
case ReadLock:
return l.lockable(r, func(value Lock) bool {
- // If there is no writer, there's no problem adding
- // another reader.
- if !value.HasWriter {
+ // If there is no writer, there's no problem adding another reader.
+ if value.Writer == nil {
return true
}
// If there is a writer, then it must be the same uid
@@ -289,10 +278,9 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
case WriteLock:
return l.lockable(r, func(value Lock) bool {
// If there are only readers.
- if !value.HasWriter {
- // Then this uid can only take a write lock if
- // this is a private upgrade, meaning that the
- // only reader is uid.
+ if value.Writer == nil {
+ // Then this uid can only take a write lock if this is a private
+ // upgrade, meaning that the only reader is uid.
return len(value.Readers) == 1 && value.Readers[uid]
}
// If the uid is already a writer on this region, then
@@ -304,7 +292,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
}
}
-// lock returns true if uid took a lock of type t on the entire range of LockRange.
+// lock returns true if uid took a lock of type t on the entire range of
+// LockRange.
//
// Preconditions: r.Start <= r.End (will panic otherwise).
func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
@@ -339,7 +328,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
seg, _ = l.SplitUnchecked(seg, r.End)
}
- // Set the lock on the segment. This is guaranteed to
+ // Set the lock on the segment. This is guaranteed to
// always be safe, given canLock above.
value := seg.ValuePtr()
value.lock(uid, t)
@@ -386,7 +375,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
value := seg.Value()
var remove bool
- if value.HasWriter && value.Writer == uid {
+ if value.Writer == uid {
// If we are unlocking a writer, then since there can
// only ever be one writer and no readers, then this
// lock should always be removed from the set.
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
index 8a3ace0c1..50a16e662 100644
--- a/pkg/sentry/fs/lock/lock_set_functions.go
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -44,14 +44,9 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
return Lock{}, false
}
}
- if val1.HasWriter != val2.HasWriter {
+ if val1.Writer != val2.Writer {
return Lock{}, false
}
- if val1.HasWriter {
- if val1.Writer != val2.Writer {
- return Lock{}, false
- }
- }
return val1, true
}
@@ -62,7 +57,6 @@ func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock)
for k, v := range val.Readers {
val0.Readers[k] = v
}
- val0.HasWriter = val.HasWriter
val0.Writer = val.Writer
return val, val0
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
index ba002aeb7..fad90984b 100644
--- a/pkg/sentry/fs/lock/lock_test.go
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -42,9 +42,6 @@ func equals(e0, e1 []entry) bool {
if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) {
return false
}
- if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter {
- return false
- }
if e0[i].Lock.Writer != e1[i].Lock.Writer {
return false
}
@@ -105,7 +102,7 @@ func TestCanLock(t *testing.T) {
LockRange: LockRange{2048, 3072},
},
{
- Lock: Lock{HasWriter: true, Writer: 1},
+ Lock: Lock{Writer: 1},
LockRange: LockRange{3072, 4096},
},
})
@@ -241,7 +238,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -254,7 +251,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -273,7 +270,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -301,7 +298,7 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
{
@@ -318,7 +315,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -550,7 +547,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -594,7 +591,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 3072},
},
{
@@ -633,7 +630,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -663,11 +660,11 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -675,28 +672,30 @@ func TestSetLock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- success := l.lock(test.uid, test.lockType, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
+ r := LockRange{Start: test.start, End: test.end}
+ success := l.lock(test.uid, test.lockType, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
- if success != test.success {
- t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success)
- continue
- }
+ if success != test.success {
+ t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success)
+ return
+ }
- if success {
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
+ if success {
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
}
- }
+ })
}
}
@@ -782,7 +781,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -824,7 +823,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -837,7 +836,7 @@ func TestUnlock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -876,7 +875,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -889,7 +888,7 @@ func TestUnlock(t *testing.T) {
// 0 4096
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
},
@@ -906,7 +905,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -974,7 +973,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -991,7 +990,7 @@ func TestUnlock(t *testing.T) {
// 0 8 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 8},
},
{
@@ -1008,7 +1007,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1025,7 +1024,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 8192 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1041,19 +1040,21 @@ func TestUnlock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- l.unlock(test.uid, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
- }
+ r := LockRange{Start: test.start, End: test.end}
+ l.unlock(test.uid, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ })
}
}
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
index bd5dac373..66e949c95 100644
--- a/pkg/sentry/fs/user/BUILD
+++ b/pkg/sentry/fs/user/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index fbd4547a7..397e96045 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -31,7 +32,15 @@ import (
// ResolveExecutablePath resolves the given executable name given the working
// dir and environment.
-func ResolveExecutablePath(ctx context.Context, creds *auth.Credentials, mns *fs.MountNamespace, envv []string, wd, name string) (string, error) {
+func ResolveExecutablePath(ctx context.Context, args *kernel.CreateProcessArgs) (string, error) {
+ name := args.Filename
+ if len(name) == 0 {
+ if len(args.Argv) == 0 {
+ return "", fmt.Errorf("no filename or command provided")
+ }
+ name = args.Argv[0]
+ }
+
// Absolute paths can be used directly.
if path.IsAbs(name) {
return name, nil
@@ -40,6 +49,7 @@ func ResolveExecutablePath(ctx context.Context, creds *auth.Credentials, mns *fs
// Paths with '/' in them should be joined to the working directory, or
// to the root if working directory is not set.
if strings.IndexByte(name, '/') > 0 {
+ wd := args.WorkingDirectory
if wd == "" {
wd = "/"
}
@@ -49,10 +59,24 @@ func ResolveExecutablePath(ctx context.Context, creds *auth.Credentials, mns *fs
return path.Join(wd, name), nil
}
- // Otherwise, We must lookup the name in the paths, starting from the
- // calling context's root directory.
- paths := getPath(envv)
+ // Otherwise, We must lookup the name in the paths.
+ paths := getPath(args.Envv)
+ if kernel.VFS2Enabled {
+ f, err := resolveVFS2(ctx, args.Credentials, args.MountNamespaceVFS2, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+ }
+
+ f, err := resolve(ctx, args.MountNamespace, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+}
+func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name string) (string, error) {
root := fs.RootFromContext(ctx)
if root == nil {
// Caller has no root. Don't bother traversing anything.
@@ -95,30 +119,7 @@ func ResolveExecutablePath(ctx context.Context, creds *auth.Credentials, mns *fs
return "", syserror.ENOENT
}
-// ResolveExecutablePathVFS2 resolves the given executable name given the
-// working dir and environment.
-func ResolveExecutablePathVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, envv []string, wd, name string) (string, error) {
- // Absolute paths can be used directly.
- if path.IsAbs(name) {
- return name, nil
- }
-
- // Paths with '/' in them should be joined to the working directory, or
- // to the root if working directory is not set.
- if strings.IndexByte(name, '/') > 0 {
- if wd == "" {
- wd = "/"
- }
- if !path.IsAbs(wd) {
- return "", fmt.Errorf("working directory %q must be absolute", wd)
- }
- return path.Join(wd, name), nil
- }
-
- // Otherwise, We must lookup the name in the paths, starting from the
- // calling context's root directory.
- paths := getPath(envv)
-
+func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
root := mns.Root()
defer root.DecRef()
for _, p := range paths {
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 585764223..cf440dce8 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -23,6 +23,7 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/unimpl",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index c03c65445..9b0e0cca2 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -116,6 +117,8 @@ type rootInode struct {
kernfs.InodeNotSymlink
kernfs.OrderedChildren
+ locks lock.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -183,7 +186,7 @@ func (i *rootInode) masterClose(t *Terminal) {
// Open implements kernfs.Inode.Open.
func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 7a7ce5d81..1d22adbe3 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -34,6 +35,8 @@ type masterInode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks lock.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -55,6 +58,7 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf
inode: mi,
t: t,
}
+ fd.LockFD.Init(&mi.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
mi.DecRef()
return nil, err
@@ -85,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds
type masterFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
inode *masterInode
t *Terminal
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index 526cd406c..7fe475080 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -33,6 +34,8 @@ type slaveInode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks lock.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -51,6 +54,7 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs
fd := &slaveFileDescription{
inode: si,
}
+ fd.LockFD.Init(&si.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
si.DecRef()
return nil, err
@@ -91,6 +95,7 @@ func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds
type slaveFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
inode *slaveInode
}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index c573d7935..d12d78b84 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -37,6 +37,7 @@ type EventFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// queue is used to notify interested parties when the event object
// becomes readable or writable.
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index ff861d0fe..973fa0def 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -60,6 +60,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index a2d8c3ad6..8bb104ff0 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -58,15 +58,16 @@ var _ io.ReaderAt = (*blockMapFile)(nil)
// newBlockMapFile is the blockMapFile constructor. It initializes the file to
// physical blocks map with (at most) the first 12 (direct) blocks.
-func newBlockMapFile(regFile regularFile) (*blockMapFile, error) {
- file := &blockMapFile{regFile: regFile}
+func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
+ file := &blockMapFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
for i := uint(0); i < 4; i++ {
- file.coverage[i] = getCoverage(regFile.inode.blkSize, i)
+ file.coverage[i] = getCoverage(file.regFile.inode.blkSize, i)
}
- blkMap := regFile.inode.diskInode.Data()
+ blkMap := file.regFile.inode.diskInode.Data()
binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 181727ef7..6fa84e7aa 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -85,20 +85,6 @@ func (n *blkNumGen) next() uint32 {
// the inode covers and that is written to disk.
func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
- regFile := regularFile{
- inode: inode{
- fs: &filesystem{
- dev: bytes.NewReader(mockDisk),
- },
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: getMockBMFileFize(),
- },
- },
- blkSize: uint64(mockBMBlkSize),
- },
- }
-
var fileData []byte
blkNums := newBlkNumGen()
var data []byte
@@ -125,9 +111,20 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
- copy(regFile.inode.diskInode.Data(), data)
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: getMockBMFileFize(),
+ },
+ },
+ blkSize: uint64(mockBMBlkSize),
+ }
+ copy(args.diskInode.Data(), data)
- mockFile, err := newBlockMapFile(regFile)
+ mockFile, err := newBlockMapFile(args)
if err != nil {
t.Fatalf("newBlockMapFile failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 12b875c8f..43be6928a 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -54,16 +54,15 @@ type directory struct {
}
// newDirectory is the directory constructor.
-func newDirectory(inode inode, newDirent bool) (*directory, error) {
+func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
file := &directory{
- inode: inode,
childCache: make(map[string]*dentry),
childMap: make(map[string]*dirent),
}
- file.inode.impl = file
+ file.inode.init(args, file)
// Initialize childList by reading dirents from the underlying file.
- if inode.diskInode.Flags().Index {
+ if args.diskInode.Flags().Index {
// TODO(b/134676337): Support hash tree directories. Currently only the '.'
// and '..' entries are read in.
@@ -74,7 +73,7 @@ func newDirectory(inode inode, newDirent bool) (*directory, error) {
// The dirents are organized in a linear array in the file data.
// Extract the file data and decode the dirents.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -82,7 +81,7 @@ func newDirectory(inode inode, newDirent bool) (*directory, error) {
// buf is used as scratch space for reading in dirents from disk and
// unmarshalling them into dirent structs.
buf := make([]byte, disklayout.DirentSize)
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
for off, inc := uint64(0), uint64(0); off < size; off += inc {
toRead := size - off
if toRead > disklayout.DirentSize {
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 11dcc0346..c36225a7c 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -38,9 +38,10 @@ var _ io.ReaderAt = (*extentFile)(nil)
// newExtentFile is the extent file constructor. It reads the entire extent
// tree into memory.
// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
-func newExtentFile(regFile regularFile) (*extentFile, error) {
- file := &extentFile{regFile: regFile}
+func newExtentFile(args inodeArgs) (*extentFile, error) {
+ file := &extentFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
err := file.buildExtTree()
if err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index a2382daa3..cd10d46ee 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -177,21 +177,19 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
t.Helper()
mockDisk := make([]byte, mockExtentBlkSize*10)
- mockExtentFile := &extentFile{
- regFile: regularFile{
- inode: inode{
- fs: &filesystem{
- dev: bytes.NewReader(mockDisk),
- },
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
- },
- },
- blkSize: mockExtentBlkSize,
+ mockExtentFile := &extentFile{}
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
},
},
+ blkSize: mockExtentBlkSize,
}
+ mockExtentFile.regFile.inode.init(args, &mockExtentFile.regFile)
fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, node0, mockExtentBlkSize)
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index 92f7da40d..90b086468 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -26,6 +26,7 @@ import (
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
}
func (fd *fileDescription) filesystem() *filesystem {
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 485f86f4b..5caaf14ed 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -54,6 +55,8 @@ type inode struct {
// diskInode gives us access to the inode struct on disk. Immutable.
diskInode disklayout.Inode
+ locks lock.FileLocks
+
// This is immutable. The first field of the implementations must have inode
// as the first field to ensure temporality.
impl interface{}
@@ -115,7 +118,7 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
// Build the inode based on its type.
- inode := inode{
+ args := inodeArgs{
fs: fs,
inodeNum: inodeNum,
blkSize: blkSize,
@@ -124,19 +127,19 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
switch diskInode.Mode().FileType() {
case linux.ModeSymlink:
- f, err := newSymlink(inode)
+ f, err := newSymlink(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeRegular:
- f, err := newRegularFile(inode)
+ f, err := newRegularFile(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeDirectory:
- f, err := newDirectory(inode, fs.sb.IncompatibleFeatures().DirentFileType)
+ f, err := newDirectory(args, fs.sb.IncompatibleFeatures().DirentFileType)
if err != nil {
return nil, err
}
@@ -147,6 +150,21 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
}
+type inodeArgs struct {
+ fs *filesystem
+ inodeNum uint32
+ blkSize uint64
+ diskInode disklayout.Inode
+}
+
+func (in *inode) init(args inodeArgs, impl interface{}) {
+ in.fs = args.fs
+ in.inodeNum = args.inodeNum
+ in.blkSize = args.blkSize
+ in.diskInode = args.diskInode
+ in.impl = impl
+}
+
// open creates and returns a file description for the dentry passed in.
func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
@@ -157,6 +175,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
+ fd.LockFD.Init(&in.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -168,6 +187,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
return nil, syserror.EISDIR
}
var fd directoryFD
+ fd.LockFD.Init(&in.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -178,6 +198,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
return nil, syserror.ELOOP
}
var fd symlinkFD
+ fd.LockFD.Init(&in.locks)
fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index 30135ddb0..152036b2e 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -43,28 +43,19 @@ type regularFile struct {
// newRegularFile is the regularFile constructor. It figures out what kind of
// file this is and initializes the fileReader.
-func newRegularFile(inode inode) (*regularFile, error) {
- regFile := regularFile{
- inode: inode,
- }
-
- inodeFlags := inode.diskInode.Flags()
-
- if inodeFlags.Extents {
- file, err := newExtentFile(regFile)
+func newRegularFile(args inodeArgs) (*regularFile, error) {
+ if args.diskInode.Flags().Extents {
+ file, err := newExtentFile(args)
if err != nil {
return nil, err
}
-
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
- file, err := newBlockMapFile(regFile)
+ file, err := newBlockMapFile(args)
if err != nil {
return nil, err
}
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
@@ -77,6 +68,7 @@ func (in *inode) isRegular() bool {
// vfs.FileDescriptionImpl.
type regularFileFD struct {
fileDescription
+ vfs.LockFD
// off is the file offset. off is accessed using atomic memory operations.
off int64
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index 1447a4dc1..acb28d85b 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -30,18 +30,17 @@ type symlink struct {
// newSymlink is the symlink constructor. It reads out the symlink target from
// the inode (however it might have been stored).
-func newSymlink(inode inode) (*symlink, error) {
- var file *symlink
+func newSymlink(args inodeArgs) (*symlink, error) {
var link []byte
// If the symlink target is lesser than 60 bytes, its stores in inode.Data().
// Otherwise either extents or block maps will be used to store the link.
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
if size < 60 {
- link = inode.diskInode.Data()[:size]
+ link = args.diskInode.Data()[:size]
} else {
// Create a regular file out of this inode and read out the target.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -52,8 +51,8 @@ func newSymlink(inode inode) (*symlink, error) {
}
}
- file = &symlink{inode: inode, target: string(link)}
- file.inode.impl = file
+ file := &symlink{target: string(link)}
+ file.inode.init(args, file)
return file, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index f5f35a3bc..5cdeeaeb5 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -54,6 +54,7 @@ go_library(
"//pkg/p9",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/host",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
@@ -68,6 +69,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 36e0e1856..40933b74b 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -801,6 +801,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
return nil, err
}
fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
AllowDirectIO: true,
}); err != nil {
@@ -826,6 +827,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
}
}
fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -842,7 +844,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
}
case linux.S_IFIFO:
if d.isSynthetic() {
- return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags)
+ return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks)
}
}
return d.openSpecialFileLocked(ctx, mnt, opts)
@@ -902,7 +904,7 @@ retry:
return nil, err
}
}
- fd, err := newSpecialFileFD(h, mnt, d, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
@@ -989,6 +991,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
var childVFSFD *vfs.FileDescription
if useRegularFileFD {
fd := &regularFileFD{}
+ fd.LockFD.Init(&child.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
AllowDirectIO: true,
}); err != nil {
@@ -1003,7 +1006,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
if fdobj != nil {
h.fd = int32(fdobj.Release())
}
- fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 3f3bd56f0..0d88a328e 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -45,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -52,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/usermem"
@@ -662,6 +664,8 @@ type dentry struct {
// If this dentry represents a synthetic named pipe, pipe is the pipe
// endpoint bound to this file.
pipe *pipe.VFSPipe
+
+ locks lock.FileLocks
}
// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
@@ -1366,6 +1370,9 @@ func (d *dentry) decLinks() {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ lockLogging sync.Once
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -1416,3 +1423,19 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name)
}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("File lock using gofer file handled internally.")
+ })
+ return fd.LockFD.LockBSD(ctx, uid, t, block)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("Range lock using gofer file handled internally.")
+ })
+ return fd.LockFD.LockPOSIX(ctx, uid, t, rng, block)
+}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index ff6126b87..289efdd25 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -51,7 +52,7 @@ type specialFileFD struct {
off int64
}
-func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) {
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks, flags uint32) (*specialFileFD, error) {
ftype := d.fileType()
seekable := ftype == linux.S_IFREG
mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK
@@ -60,6 +61,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci
seekable: seekable,
mayBlock: mayBlock,
}
+ fd.LockFD.Init(locks)
if mayBlock && h.fd >= 0 {
if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 2608e7e1d..1d5aa82dc 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -38,6 +38,9 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
// Preconditions: fs.interop != InteropModeShared.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index ca0fe6d2b..54f16ad63 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -39,6 +39,7 @@ go_library(
"//pkg/sentry/unimpl",
"//pkg/sentry/uniqueid",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 18b127521..5ec5100b8 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -34,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -182,6 +183,8 @@ type inode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks lock.FileLocks
+
// When the reference count reaches zero, the host fd is closed.
refs.AtomicRefCount
@@ -468,7 +471,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
return nil, err
}
// Currently, we only allow Unix sockets to be imported.
- return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d)
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks)
}
// TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
@@ -478,6 +481,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
fileDescription: fileDescription{inode: i},
termios: linux.DefaultSlaveTermios,
}
+ fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
@@ -486,6 +490,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
}
fd := &fileDescription{inode: i}
+ fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
@@ -497,6 +502,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
// inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
// cached to reduce indirections and casting. fileDescription does not hold
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index ef34cb28a..0299dbde9 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -49,6 +49,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
@@ -67,6 +68,7 @@ go_test(
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserror",
"//pkg/usermem",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 1568a9d49..6418de0a3 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -38,7 +39,8 @@ type DynamicBytesFile struct {
InodeNotDirectory
InodeNotSymlink
- data vfs.DynamicBytesSource
+ locks lock.FileLocks
+ data vfs.DynamicBytesSource
}
var _ Inode = (*DynamicBytesFile)(nil)
@@ -55,7 +57,7 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint
// Open implements Inode.Open.
func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &DynamicBytesFD{}
- if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -77,13 +79,15 @@ func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credent
type DynamicBytesFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
inode Inode
}
// Init initializes a DynamicBytesFD.
-func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) error {
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *lock.FileLocks, flags uint32) error {
+ fd.LockFD.Init(locks)
if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 8284e76a7..33a5968ca 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -42,6 +43,7 @@ import (
type GenericDirectoryFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
children *OrderedChildren
@@ -55,9 +57,9 @@ type GenericDirectoryFD struct {
// NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its
// dentry.
-func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
+func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
fd := &GenericDirectoryFD{}
- if err := fd.Init(children, opts); err != nil {
+ if err := fd.Init(children, locks, opts); err != nil {
return nil, err
}
if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
@@ -69,11 +71,12 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre
// Init initializes a GenericDirectoryFD. Use it when overriding
// GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the
// correct implementation.
-func (fd *GenericDirectoryFD) Init(children *OrderedChildren, opts *vfs.OpenOptions) error {
+func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) error {
if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
// Can't open directories for writing.
return syserror.EISDIR
}
+ fd.LockFD.Init(locks)
fd.children = children
return nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 982daa2e6..0e4927215 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -555,6 +556,8 @@ type StaticDirectory struct {
InodeAttrs
InodeNoDynamicLookup
OrderedChildren
+
+ locks lock.FileLocks
}
var _ Inode = (*StaticDirectory)(nil)
@@ -584,7 +587,7 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3
// Open implements kernfs.Inode.
func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
+ fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 412cf6ac9..6749facf7 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -100,8 +101,10 @@ type readonlyDir struct {
kernfs.InodeNotSymlink
kernfs.InodeNoDynamicLookup
kernfs.InodeDirectoryNoNewChildren
-
kernfs.OrderedChildren
+
+ locks lock.FileLocks
+
dentry kernfs.Dentry
}
@@ -117,7 +120,7 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod
}
func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
@@ -128,10 +131,12 @@ type dir struct {
attrs
kernfs.InodeNotSymlink
kernfs.InodeNoDynamicLookup
+ kernfs.OrderedChildren
+
+ locks lock.FileLocks
fs *filesystem
dentry kernfs.Dentry
- kernfs.OrderedChildren
}
func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
@@ -147,7 +152,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
}
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
index 5950a2d59..c618dbe6c 100644
--- a/pkg/sentry/fsimpl/pipefs/BUILD
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserror",
"//pkg/usermem",
],
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index cab771211..e4dabaa33 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -81,7 +82,8 @@ type inode struct {
kernfs.InodeNotSymlink
kernfs.InodeNoopRefCount
- pipe *pipe.VFSPipe
+ locks lock.FileLocks
+ pipe *pipe.VFSPipe
ino uint64
uid auth.KUID
@@ -147,7 +149,7 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.
// Open implements kernfs.Inode.Open.
func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags)
+ return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks)
}
// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 17c1342b5..351ba4ee9 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 36a911db4..e2cdb7ee9 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -37,6 +38,8 @@ type subtasksInode struct {
kernfs.OrderedChildren
kernfs.AlwaysValid
+ locks lock.FileLocks
+
fs *filesystem
task *kernel.Task
pidns *kernel.PIDNamespace
@@ -153,7 +156,7 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro
// Open implements kernfs.Inode.
func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &subtasksFD{task: i.task}
- if err := fd.Init(&i.OrderedChildren, &opts); err != nil {
+ if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil {
return nil, err
}
if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 482055db1..44078a765 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -38,6 +39,8 @@ type taskInode struct {
kernfs.InodeAttrs
kernfs.OrderedChildren
+ locks lock.FileLocks
+
task *kernel.Task
}
@@ -103,7 +106,7 @@ func (i *taskInode) Valid(ctx context.Context) bool {
// Open implements kernfs.Inode.
func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 44ccc9e4a..ef6c1d04f 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -53,6 +54,8 @@ func taskFDExists(t *kernel.Task, fd int32) bool {
}
type fdDir struct {
+ locks lock.FileLocks
+
fs *filesystem
task *kernel.Task
@@ -143,7 +146,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
// Open implements kernfs.Inode.
func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
@@ -270,7 +273,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry,
// Open implements kernfs.Inode.
func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 2f297e48a..e5eaa91cd 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -775,6 +776,8 @@ type namespaceInode struct {
kernfs.InodeNoopRefCount
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+
+ locks lock.FileLocks
}
var _ kernfs.Inode = (*namespaceInode)(nil)
@@ -791,6 +794,7 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32
func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &namespaceFD{inode: i}
i.IncRef()
+ fd.LockFD.Init(&i.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -801,6 +805,7 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *
// /proc/[pid]/ns/*.
type namespaceFD struct {
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
inode *namespaceInode
@@ -825,8 +830,3 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
func (fd *namespaceFD) Release() {
fd.inode.DecRef()
}
-
-// OnClose implements FileDescriptionImpl.
-func (*namespaceFD) OnClose(context.Context) error {
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index b51d43954..58c8b9d05 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -43,6 +44,8 @@ type tasksInode struct {
kernfs.OrderedChildren
kernfs.AlwaysValid
+ locks lock.FileLocks
+
fs *filesystem
pidns *kernel.PIDNamespace
@@ -197,7 +200,7 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
// Open implements kernfs.Inode.
func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index d29ef3f83..242ba9b5d 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -31,6 +31,7 @@ type SignalFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// target is the original signal target task.
//
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index a741e2bb6..237f17def 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 0af373604..b84463d3a 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -98,8 +99,10 @@ type dir struct {
kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.InodeDirectoryNoNewChildren
-
kernfs.OrderedChildren
+
+ locks lock.FileLocks
+
dentry kernfs.Dentry
}
@@ -121,7 +124,7 @@ func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.Set
// Open implements kernfs.Inode.Open.
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 60c92d626..2dc90d484 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -32,6 +32,7 @@ type TimerFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
events waiter.Queue
timer *ktime.Timer
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
index 83bf885ee..ac54d420d 100644
--- a/pkg/sentry/fsimpl/tmpfs/device_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -29,7 +29,7 @@ type deviceFile struct {
minor uint32
}
-func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
+func (fs *filesystem) newDeviceFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
file := &deviceFile{
kind: kind,
major: major,
@@ -43,7 +43,7 @@ func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode
default:
panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, kuid, kgid, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 70387cb9c..913b8a6c5 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -48,9 +48,9 @@ type directory struct {
childList dentryList
}
-func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *directory {
+func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *directory {
dir := &directory{}
- dir.inode.init(dir, fs, creds, linux.S_IFDIR|mode)
+ dir.inode.init(dir, fs, kuid, kgid, linux.S_IFDIR|mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
dir.dentry.inode = &dir.inode
dir.dentry.vfsd.Init(&dir.dentry)
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 183eb975c..72399b321 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -256,11 +256,12 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
if parentDir.inode.nlink == maxLinks {
return syserror.EMLINK
}
parentDir.inode.incLinksLocked() // from child's ".."
- childDir := fs.newDirectory(rp.Credentials(), opts.Mode)
+ childDir := fs.newDirectory(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
parentDir.insertChildLocked(&childDir.dentry, name)
return nil
})
@@ -269,18 +270,19 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
var childInode *inode
switch opts.Mode.FileType() {
case 0, linux.S_IFREG:
- childInode = fs.newRegularFile(rp.Credentials(), opts.Mode)
+ childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
case linux.S_IFIFO:
- childInode = fs.newNamedPipe(rp.Credentials(), opts.Mode)
+ childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
case linux.S_IFBLK:
- childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFCHR:
- childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFSOCK:
- childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint)
+ childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint)
default:
return syserror.EINVAL
}
@@ -355,7 +357,8 @@ afterTrailingSymlink:
}
defer rp.Mount().EndWrite()
// Create and open the child.
- child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
parentDir.insertChildLocked(child, name)
fd, err := child.open(ctx, rp, &opts, true)
if err != nil {
@@ -396,6 +399,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
switch impl := d.inode.impl.(type) {
case *regularFile:
var fd regularFileFD
+ fd.LockFD.Init(&d.inode.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -411,15 +415,16 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, syserror.EISDIR
}
var fd directoryFD
+ fd.LockFD.Init(&d.inode.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
case *symlink:
- // Can't open symlinks without O_PATH (which is unimplemented).
+ // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH.
return nil, syserror.ELOOP
case *namedPipe:
- return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags)
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks)
case *deviceFile:
return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
case *socketFile:
@@ -676,7 +681,8 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
- child := fs.newDentry(fs.newSymlink(rp.Credentials(), target))
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target))
parentDir.insertChildLocked(child, name)
return nil
})
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 8d77b3fa8..739350cf0 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -30,9 +30,9 @@ type namedPipe struct {
// Preconditions:
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
-func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
- file.inode.init(file, fs, creds, linux.S_IFIFO|mode)
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index fee174375..77447b32c 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -85,12 +84,12 @@ type regularFile struct {
size uint64
}
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
+func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
memFile: fs.memFile,
seals: linux.F_SEAL_SEAL,
}
- file.inode.init(file, fs, creds, linux.S_IFREG|mode)
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
@@ -366,28 +365,6 @@ func (fd *regularFileFD) Sync(ctx context.Context) error {
return nil
}
-// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
-func (fd *regularFileFD) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
- return fd.inode().lockBSD(uid, t, block)
-}
-
-// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
-func (fd *regularFileFD) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
- fd.inode().unlockBSD(uid)
- return nil
-}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
- return fd.inode().lockPOSIX(uid, t, rng, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
- fd.inode().unlockPOSIX(uid, rng)
- return nil
-}
-
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
file := fd.inode().impl.(*regularFile)
diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
index 25c2321af..3ed650474 100644
--- a/pkg/sentry/fsimpl/tmpfs/socket_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -26,9 +26,9 @@ type socketFile struct {
ep transport.BoundEndpoint
}
-func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+func (fs *filesystem) newSocketFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
file := &socketFile{ep: ep}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, kuid, kgid, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index 47e075ed4..b0de5fabe 100644
--- a/pkg/sentry/fsimpl/tmpfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -24,11 +24,11 @@ type symlink struct {
target string // immutable
}
-func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode {
+func (fs *filesystem) newSymlink(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, target string) *inode {
link := &symlink{
target: target,
}
- link.inode.init(link, fs, creds, linux.S_IFLNK|0777)
+ link.inode.init(link, fs, kuid, kgid, linux.S_IFLNK|mode)
link.inode.nlink = 1 // from parent directory
return &link.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index f0e098702..71a7522af 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -30,12 +30,12 @@ package tmpfs
import (
"fmt"
"math"
+ "strconv"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -112,6 +112,58 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ rootMode := linux.FileMode(0777)
+ if rootFileType == linux.S_IFDIR {
+ rootMode = 01777
+ }
+ modeStr, ok := mopts["mode"]
+ if ok {
+ delete(mopts, "mode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode & 07777)
+ }
+ rootKUID := creds.EffectiveKUID
+ uidStr, ok := mopts["uid"]
+ if ok {
+ delete(mopts, "uid")
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKUID = kuid
+ }
+ rootKGID := creds.EffectiveKGID
+ gidStr, ok := mopts["gid"]
+ if ok {
+ delete(mopts, "gid")
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKGID = kgid
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -127,11 +179,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
var root *dentry
switch rootFileType {
case linux.S_IFREG:
- root = fs.newDentry(fs.newRegularFile(creds, 0777))
+ root = fs.newDentry(fs.newRegularFile(rootKUID, rootKGID, rootMode))
case linux.S_IFLNK:
- root = fs.newDentry(fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget))
+ root = fs.newDentry(fs.newSymlink(rootKUID, rootKGID, rootMode, tmpfsOpts.RootSymlinkTarget))
case linux.S_IFDIR:
- root = &fs.newDirectory(creds, 01777).dentry
+ root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry
default:
fs.vfsfs.DecRef()
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
@@ -258,7 +310,6 @@ type inode struct {
ctime int64 // nanoseconds
mtime int64 // nanoseconds
- // Advisory file locks, which lock at the inode level.
locks lock.FileLocks
// Inotify watches for this inode.
@@ -269,15 +320,15 @@ type inode struct {
const maxLinks = math.MaxUint32
-func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) {
if mode.FileType() == 0 {
panic("file type is required in FileMode")
}
i.fs = fs
i.refs = 1
i.mode = uint32(mode)
- i.uid = uint32(creds.EffectiveKUID)
- i.gid = uint32(creds.EffectiveKGID)
+ i.uid = uint32(kuid)
+ i.gid = uint32(kgid)
i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
// Tmpfs creation sets atime, ctime, and mtime to current time.
now := fs.clock.Now().Nanoseconds()
@@ -486,44 +537,6 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return nil
}
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) lockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
- switch i.impl.(type) {
- case *regularFile:
- return i.locks.LockBSD(uid, t, block)
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) unlockBSD(uid fslock.UniqueID) error {
- switch i.impl.(type) {
- case *regularFile:
- i.locks.UnlockBSD(uid)
- return nil
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) lockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
- switch i.impl.(type) {
- case *regularFile:
- return i.locks.LockPOSIX(uid, t, rng, block)
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) unlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) error {
- switch i.impl.(type) {
- case *regularFile:
- i.locks.UnlockPOSIX(uid, rng)
- return nil
- }
- return syserror.EBADF
-}
-
// allocatedBlocksForSize returns the number of 512B blocks needed to
// accommodate the given size in bytes, as appropriate for struct
// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
@@ -562,6 +575,9 @@ func (i *inode) isDir() bool {
}
func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
@@ -652,6 +668,7 @@ func (i *inode) userXattrSupported() bool {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -731,8 +748,7 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s
// Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
// S_IRWXUGO.
- mode := linux.FileMode(0777)
- inode := fs.newRegularFile(creds, mode)
+ inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
rf := inode.impl.(*regularFile)
if allowSeals {
rf.seals = 0
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index e057d2c6d..6862f2ef5 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) {
}
return NoID, syserror.EPERM
}
+
+// SetUID translates the provided uid to the root user namespace and updates c's
+// uids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetUID(uid UID) error {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKUID = kuid
+ c.EffectiveKUID = kuid
+ c.SavedKUID = kuid
+ return nil
+}
+
+// SetGID translates the provided gid to the root user namespace and updates c's
+// gids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetGID(gid GID) error {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKGID = kgid
+ c.EffectiveKGID = kgid
+ c.SavedKGID = kgid
+ return nil
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index dbfcef0fa..b35afafe3 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -80,9 +80,6 @@ type FDTable struct {
refs.AtomicRefCount
k *Kernel
- // uid is a unique identifier.
- uid uint64
-
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -130,7 +127,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
// drop drops the table reference.
func (f *FDTable) drop(file *fs.File) {
// Release locks.
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lock.UniqueID(f.uid), lock.LockRange{0, lock.LockEOF})
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF})
// Send inotify events.
d := file.Dirent
@@ -164,17 +161,9 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
file.DecRef()
}
-// ID returns a unique identifier for this FDTable.
-func (f *FDTable) ID() uint64 {
- return f.uid
-}
-
// NewFDTable allocates a new FDTable that may be used by tasks in k.
func (k *Kernel) NewFDTable() *FDTable {
- f := &FDTable{
- k: k,
- uid: atomic.AddUint64(&k.fdMapUids, 1),
- }
+ f := &FDTable{k: k}
f.init()
return f
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 5efeb3767..bcbeb6a39 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -194,11 +194,6 @@ type Kernel struct {
// cpuClockTickerSetting is protected by runningTasksMu.
cpuClockTickerSetting ktime.Setting
- // fdMapUids is an ever-increasing counter for generating FDTable uids.
- //
- // fdMapUids is mutable, and is accessed using atomic memory operations.
- fdMapUids uint64
-
// uniqueID is used to generate unique identifiers.
//
// uniqueID is mutable, and is accessed using atomic memory operations.
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 7bfa9075a..0db546b98 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -27,6 +27,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 2602bed72..c0e9ee1f4 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -61,11 +62,13 @@ func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
//
// Preconditions: statusFlags should not contain an open access mode.
func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
- return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags)
+ // Connected pipes share the same locks.
+ locks := &lock.FileLocks{}
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
}
// Open opens the pipe represented by vp.
-func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) {
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) (*vfs.FileDescription, error) {
vp.mu.Lock()
defer vp.mu.Unlock()
@@ -75,7 +78,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
return nil, syserror.EINVAL
}
- fd := vp.newFD(mnt, vfsd, statusFlags)
+ fd := vp.newFD(mnt, vfsd, statusFlags, locks)
// Named pipes have special blocking semantics during open:
//
@@ -127,10 +130,11 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription {
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) *vfs.FileDescription {
fd := &VFSPipeFD{
pipe: &vp.pipe,
}
+ fd.LockFD.Init(locks)
fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
@@ -159,6 +163,7 @@ type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
pipe *Pipe
}
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 00c425cca..9b69f3cbe 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -198,6 +198,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
+ oldFDTable := t.fdTable
+ t.fdTable = t.fdTable.Fork()
+ oldFDTable.DecRef()
+
// Remove FDs with the CloseOnExec flag set.
t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
return flags.CloseOnExec
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index da0ea7bb5..0adf25691 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -186,6 +186,7 @@ func (t *Timekeeper) startUpdater() {
timer := time.NewTicker(sentrytime.ApproxUpdateInterval)
t.wg.Add(1)
go func() { // S/R-SAFE: stopped during save.
+ defer t.wg.Done()
for {
// Start with an update immediately, so the clocks are
// ready ASAP.
@@ -220,7 +221,6 @@ func (t *Timekeeper) startUpdater() {
select {
case <-timer.C:
case <-t.stop:
- t.wg.Done()
return
}
}
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 00977fc08..165869028 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -75,7 +75,7 @@ var _ fs.FileOperations = (*byteReader)(nil)
// newByteReaderFile creates a fake file to read data from.
//
-// TODO(gvisor.dev/issue/1623): Convert to VFS2.
+// TODO(gvisor.dev/issue/2921): Convert to VFS2.
func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
// Create a fake inode.
inode := fs.NewInode(
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 1eeb9f317..a9836ba71 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -33,6 +33,7 @@ go_template_instance(
out = "usage_set.go",
consts = {
"minDegree": "10",
+ "trackGaps": "1",
},
imports = {
"platform": "gvisor.dev/gvisor/pkg/sentry/platform",
@@ -48,6 +49,26 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "reclaim_set",
+ out = "reclaim_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "pgalloc",
+ prefix = "reclaim",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "reclaimSetValue",
+ "Functions": "reclaimSetFunctions",
+ },
+)
+
go_library(
name = "pgalloc",
srcs = [
@@ -56,6 +77,7 @@ go_library(
"evictable_range_set.go",
"pgalloc.go",
"pgalloc_unsafe.go",
+ "reclaim_set.go",
"save_restore.go",
"usage_set.go",
],
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 2b11ea4ae..46f19d218 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -108,12 +108,6 @@ type MemoryFile struct {
usageSwapped uint64
usageLast time.Time
- // minUnallocatedPage is the minimum page that may be unallocated.
- // i.e., there are no unallocated pages below minUnallocatedPage.
- //
- // minUnallocatedPage is protected by mu.
- minUnallocatedPage uint64
-
// fileSize is the size of the backing memory file in bytes. fileSize is
// always a power-of-two multiple of chunkSize.
//
@@ -146,11 +140,9 @@ type MemoryFile struct {
// is protected by mu.
reclaimable bool
- // minReclaimablePage is the minimum page that may be reclaimable.
- // i.e., all reclaimable pages are >= minReclaimablePage.
- //
- // minReclaimablePage is protected by mu.
- minReclaimablePage uint64
+ // relcaim is the collection of regions for reclaim. relcaim is protected
+ // by mu.
+ reclaim reclaimSet
// reclaimCond is signaled (with mu locked) when reclaimable or destroyed
// transitions from false to true.
@@ -273,12 +265,10 @@ type evictableMemoryUserInfo struct {
}
const (
- chunkShift = 24
- chunkSize = 1 << chunkShift // 16 MB
+ chunkShift = 30
+ chunkSize = 1 << chunkShift // 1 GB
chunkMask = chunkSize - 1
- initialSize = chunkSize
-
// maxPage is the highest 64-bit page.
maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
)
@@ -302,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
if err := file.Truncate(0); err != nil {
return nil, err
}
- if err := file.Truncate(initialSize); err != nil {
- return nil, err
- }
f := &MemoryFile{
- opts: opts,
- fileSize: initialSize,
- file: file,
- // No pages are reclaimable. DecRef will always be able to
- // decrease minReclaimablePage from this point.
- minReclaimablePage: maxPage,
- evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ opts: opts,
+ file: file,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
}
- f.mappings.Store(make([]uintptr, initialSize/chunkSize))
+ f.mappings.Store(make([]uintptr, 0))
f.reclaimCond.L = &f.mu
if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
@@ -404,39 +387,29 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
alignment = usermem.HugePageSize
}
- start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment)
- end := start + length
- // File offsets are int64s. Since length must be strictly positive, end
- // cannot legitimately be 0.
- if end < start || int64(end) <= 0 {
+ // Find a range in the underlying file.
+ fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
+ if !ok {
return platform.FileRange{}, syserror.ENOMEM
}
- // Expand the file if needed. Double the file size on each expansion;
- // uncommitted pages have effectively no cost.
- fileSize := f.fileSize
- for int64(end) > fileSize {
- if fileSize >= 2*fileSize {
- // fileSize overflow.
- return platform.FileRange{}, syserror.ENOMEM
- }
- fileSize *= 2
- }
- if fileSize > f.fileSize {
- if err := f.file.Truncate(fileSize); err != nil {
+ // Expand the file if needed.
+ if int64(fr.End) > f.fileSize {
+ // Round the new file size up to be chunk-aligned.
+ newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
+ if err := f.file.Truncate(newFileSize); err != nil {
return platform.FileRange{}, err
}
- f.fileSize = fileSize
+ f.fileSize = newFileSize
f.mappingsMu.Lock()
oldMappings := f.mappings.Load().([]uintptr)
- newMappings := make([]uintptr, fileSize>>chunkShift)
+ newMappings := make([]uintptr, newFileSize>>chunkShift)
copy(newMappings, oldMappings)
f.mappings.Store(newMappings)
f.mappingsMu.Unlock()
}
// Mark selected pages as in use.
- fr := platform.FileRange{start, end}
if f.opts.ManualZeroing {
if err := f.forEachMappingSlice(fr, func(bs []byte) {
for i := range bs {
@@ -453,49 +426,71 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
}
- if minUnallocatedPage < start {
- f.minUnallocatedPage = minUnallocatedPage
- } else {
- // start was the first unallocated page. The next must be
- // somewhere beyond end.
- f.minUnallocatedPage = end
- }
-
return fr, nil
}
-// findUnallocatedRange returns the first unallocated page in usage of the
-// specified length and alignment beginning at page start and the first single
-// unallocated page.
-func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) {
- // Only searched until the first page is found.
- firstPage := start
- foundFirstPage := false
- alignMask := alignment - 1
- for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() {
- r := seg.Range()
+// findAvailableRange returns an available range in the usageSet.
+//
+// Note that scanning for available slots takes place from end first backwards,
+// then forwards. This heuristic has important consequence for how sequential
+// mappings can be merged in the host VMAs, given that addresses for both
+// application and sentry mappings are allocated top-down (from higher to
+// lower addresses). The file is also grown expoentially in order to create
+// space for mappings to be allocated downwards.
+//
+// Precondition: alignment must be a power of 2.
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) {
+ alignmentMask := alignment - 1
+ for gap := usage.UpperBoundGap(uint64(fileSize)); gap.Ok(); gap = gap.PrevLargeEnoughGap(length) {
+ // Start searching only at end of file.
+ end := gap.End()
+ if end > uint64(fileSize) {
+ end = uint64(fileSize)
+ }
- if !foundFirstPage && r.Start > firstPage {
- foundFirstPage = true
+ // Start at the top and align downwards.
+ start := end - length
+ if start > end {
+ break // Underflow.
}
+ start &^= alignmentMask
- if start >= r.End {
- // start was rounded up to an alignment boundary from the end
- // of a previous segment and is now beyond r.End.
+ // Is the gap still sufficient?
+ if start < gap.Start() {
continue
}
- // This segment represents allocated or reclaimable pages; only the
- // range from start to the segment's beginning is allocatable, and the
- // next allocatable range begins after the segment.
- if r.Start > start && r.Start-start >= length {
- break
+
+ // Allocate in the given gap.
+ return platform.FileRange{start, start + length}, true
+ }
+
+ // Check that it's possible to fit this allocation at the end of a file of any size.
+ min := usage.LastGap().Start()
+ min = (min + alignmentMask) &^ alignmentMask
+ if min+length < min {
+ // Overflow.
+ return platform.FileRange{}, false
+ }
+
+ // Determine the minimum file size required to fit this allocation at its end.
+ for {
+ if fileSize >= 2*fileSize {
+ // Is this because it's initially empty?
+ if fileSize == 0 {
+ fileSize += chunkSize
+ } else {
+ // fileSize overflow.
+ return platform.FileRange{}, false
+ }
+ } else {
+ // Double the current fileSize.
+ fileSize *= 2
}
- start = (r.End + alignMask) &^ alignMask
- if !foundFirstPage {
- firstPage = r.End
+ start := (uint64(fileSize) - length) &^ alignmentMask
+ if start >= min {
+ return platform.FileRange{start, start + length}, true
}
}
- return start, firstPage
}
// AllocateAndFill allocates memory of the given kind and fills it by calling
@@ -616,6 +611,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
val.refs--
if val.refs == 0 {
+ f.reclaim.Add(seg.Range(), reclaimSetValue{})
freed = true
// Reclassify memory as System, until it's freed by the reclaim
// goroutine.
@@ -628,10 +624,6 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
if freed {
- if fr.Start < f.minReclaimablePage {
- // We've freed at least one lower page.
- f.minReclaimablePage = fr.Start
- }
f.reclaimable = true
f.reclaimCond.Signal()
}
@@ -1030,6 +1022,7 @@ func (f *MemoryFile) String() string {
// for allocation.
func (f *MemoryFile) runReclaim() {
for {
+ // N.B. We must call f.markReclaimed on the returned FrameRange.
fr, ok := f.findReclaimable()
if !ok {
break
@@ -1085,6 +1078,10 @@ func (f *MemoryFile) runReclaim() {
}
}
+// findReclaimable finds memory that has been marked for reclaim.
+//
+// Note that there returned range will be removed from tracking. It
+// must be reclaimed (removed from f.usage) at this point.
func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
@@ -1103,18 +1100,15 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
f.reclaimCond.Wait()
}
- // Allocate returns the first usable range in offset order and is
- // currently a linear scan, so reclaiming from the beginning of the
- // file minimizes the expected latency of Allocate.
- for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() {
- if seg.ValuePtr().refs == 0 {
- f.minReclaimablePage = seg.End()
- return seg.Range(), true
- }
+ // Allocate works from the back of the file inwards, so reclaim
+ // preserves this order to minimize the cost of the search.
+ if seg := f.reclaim.LastSegment(); seg.Ok() {
+ fr := seg.Range()
+ f.reclaim.Remove(seg)
+ return fr, true
}
- // No pages are reclaimable.
+ // Nothing is reclaimable.
f.reclaimable = false
- f.minReclaimablePage = maxPage
}
}
@@ -1122,8 +1116,8 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
- // All of fr should be mapped to a single uncommitted reclaimable segment
- // accounted to System.
+ // All of fr should be mapped to a single uncommitted reclaimable
+ // segment accounted to System.
if !seg.Ok() {
panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
}
@@ -1137,14 +1131,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
}); got != want {
panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
}
- // Deallocate reclaimed pages. Even though all of seg is reclaimable, the
- // caller of markReclaimed may not have decommitted it, so we can only mark
- // fr as reclaimed.
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable,
+ // the caller of markReclaimed may not have decommitted it, so we can
+ // only mark fr as reclaimed.
f.usage.Remove(f.usage.Isolate(seg, fr))
- if fr.Start < f.minUnallocatedPage {
- // We've deallocated at least one lower page.
- f.minUnallocatedPage = fr.Start
- }
}
// StartEvictions requests that f evict all evictable allocations. It does not
@@ -1255,3 +1245,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal
func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
return evictableRangeSetValue{}, evictableRangeSetValue{}
}
+
+// reclaimSetValue is the value type of reclaimSet.
+type reclaimSetValue struct{}
+
+type reclaimSetFunctions struct{}
+
+func (reclaimSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (reclaimSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
+}
+
+func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+ return reclaimSetValue{}, true
+}
+
+func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+ return reclaimSetValue{}, reclaimSetValue{}
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
index 293f22c6b..b5b68eb52 100644
--- a/pkg/sentry/pgalloc/pgalloc_test.go
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -23,39 +23,49 @@ import (
const (
page = usermem.PageSize
hugepage = usermem.HugePageSize
+ topPage = (1 << 63) - page
)
func TestFindUnallocatedRange(t *testing.T) {
for _, test := range []struct {
- desc string
- usage *usageSegmentDataSlices
- start uint64
- length uint64
- alignment uint64
- unallocated uint64
- minUnallocated uint64
+ desc string
+ usage *usageSegmentDataSlices
+ fileSize int64
+ length uint64
+ alignment uint64
+ start uint64
+ expectFail bool
}{
{
- desc: "Initial allocation succeeds",
- usage: &usageSegmentDataSlices{},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ desc: "Initial allocation succeeds",
+ usage: &usageSegmentDataSlices{},
+ length: page,
+ alignment: page,
+ start: chunkSize - page, // Grows by chunkSize, allocate down.
},
{
- desc: "Allocation begins at start of file",
+ desc: "Allocation finds empty space at start of file",
usage: &usageSegmentDataSlices{
Start: []uint64{page},
End: []uint64{2 * page},
Values: []usageInfo{{refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocation finds empty space at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0},
+ End: []uint64{page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "In-use frames are not allocatable",
@@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 3 * page, // Double fileSize, allocate top-down.
},
{
desc: "Reclaimable frames are not allocatable",
@@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: 3 * page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: 5 * page, // Double fileSize, grow down.
},
{
desc: "Gaps between in-use frames are allocatable",
@@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "Inadequately-sized gaps are rejected",
@@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: 2 * page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: 2 * page,
+ alignment: page,
+ start: 4 * page, // Double fileSize, grow down.
},
{
- desc: "Hugepage alignment is honored",
+ desc: "Alignment is honored at end of file",
usage: &usageSegmentDataSlices{
Start: []uint64{0, hugepage + page},
// Hugepage-sized gap here that shouldn't be allocated from
@@ -118,37 +124,95 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, hugepage + 2*page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: hugepage,
- alignment: hugepage,
- unallocated: 2 * hugepage,
- minUnallocated: page,
+ fileSize: hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down.
+ },
+ {
+ desc: "Alignment is honored before end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2*hugepage + page},
+ // Page will need to be shifted down from top.
+ End: []uint64{page, 2*hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 2*hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: hugepage,
},
{
- desc: "Pages before start ignored",
+ desc: "Allocations are compact if possible",
usage: &usageSegmentDataSlices{
Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 4 * page,
+ length: page,
+ alignment: page,
+ start: 2 * page,
+ },
+ {
+ desc: "Top-down allocation within one gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 4 * page, 7 * page},
+ End: []uint64{2 * page, 5 * page, 8 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 6 * page,
+ },
+ {
+ desc: "Top-down allocation between multiple gaps",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page, 5 * page},
+ End: []uint64{2 * page, 4 * page, 6 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 6 * page,
+ length: page,
+ alignment: page,
+ start: 4 * page,
},
{
- desc: "start may be in the middle of segment",
+ desc: "Top-down allocation with large top gap",
usage: &usageSegmentDataSlices{
- Start: []uint64{0, 3 * page},
+ Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 7 * page,
+ },
+ {
+ desc: "Gaps found with possible overflow",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, topPage - page},
+ End: []uint64{2 * page, topPage},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: topPage,
+ length: page,
+ alignment: page,
+ start: topPage - 2*page,
+ },
+ {
+ desc: "Overflow detected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{topPage},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: topPage,
+ length: 2 * page,
+ alignment: page,
+ expectFail: true,
},
} {
t.Run(test.desc, func(t *testing.T) {
@@ -156,12 +220,18 @@ func TestFindUnallocatedRange(t *testing.T) {
if err := usage.ImportSortedSlices(test.usage); err != nil {
t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err)
}
- unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment)
- if unallocated != test.unallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated)
+ fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment)
+ if !test.expectFail && !ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if test.expectFail && ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.Start != test.start {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
}
- if minUnallocated != test.minUnallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated)
+ if ok && fr.End != test.start+test.length {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length)
}
})
}
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 159f7eafd..4792454c4 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -6,8 +6,8 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
- "allocator.go",
"bluepill.go",
+ "bluepill_allocator.go",
"bluepill_amd64.go",
"bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index be213bfe8..faf1d5e1c 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -26,16 +26,15 @@ import (
// dirtySet tracks vCPUs for invalidation.
type dirtySet struct {
- vCPUs []uint64
+ vCPUMasks []uint64
}
// forEach iterates over all CPUs in the dirty set.
+//
+//go:nosplit
func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- for index := range ds.vCPUs {
- mask := atomic.SwapUint64(&ds.vCPUs[index], 0)
+ for index := range ds.vCPUMasks {
+ mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0)
if mask != 0 {
for bit := 0; bit < 64; bit++ {
if mask&(1<<uint64(bit)) == 0 {
@@ -54,7 +53,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
index := uint64(c.id) / 64
bit := uint64(1) << uint(c.id%64)
- oldValue := atomic.LoadUint64(&ds.vCPUs[index])
+ oldValue := atomic.LoadUint64(&ds.vCPUMasks[index])
if oldValue&bit != 0 {
return false // Not clean.
}
@@ -62,7 +61,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
// Set the bit unilaterally, and ensure that a flush takes place. Note
// that it's possible for races to occur here, but since the flush is
// taking place long after these lines there's no race in practice.
- atomicbitops.OrUint64(&ds.vCPUs[index], bit)
+ atomicbitops.OrUint64(&ds.vCPUMasks[index], bit)
return true // Previously clean.
}
@@ -113,7 +112,12 @@ type hostMapEntry struct {
length uintptr
}
-func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+// mapLocked maps the given host entry.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
for m.length > 0 {
physical, length, ok := translateToPhysical(m.addr)
if !ok {
@@ -133,18 +137,10 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// important; if the pagetable mappings were installed before
// ensuring the physical pages were available, then some other
// thread could theoretically access them.
- //
- // Due to the way KVM's shadow paging implementation works,
- // modifications to the page tables while in host mode may not
- // be trapped, leading to the shadow pages being out of sync.
- // Therefore, we need to ensure that we are in guest mode for
- // page table modifications. See the call to bluepill, below.
- as.machine.retryInGuest(func() {
- inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
- AccessType: at,
- User: true,
- }, physical) || inv
- })
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
m.addr += length
m.length -= length
addr += usermem.Addr(length)
@@ -176,6 +172,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return err
}
+ // See block in mapLocked.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+
// Map the mappings in the sentry's address space (guest physical memory)
// into the application's address space (guest virtual memory).
inv := false
@@ -190,7 +190,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
_ = s[i] // Touch to commit.
}
}
- prev := as.mapHost(addr, hostMapEntry{
+
+ // See bluepill_allocator.go.
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Perform the mapping.
+ prev := as.mapLocked(addr, hostMapEntry{
addr: b.Addr(),
length: uintptr(b.Len()),
}, at)
@@ -204,17 +209,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return nil
}
+// unmapLocked is an escape-checked wrapped around Unmap.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+ return as.pageTables.Unmap(addr, uintptr(length))
+}
+
// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
as.mu.Lock()
defer as.mu.Unlock()
- // See above re: retryInGuest.
- var prev bool
- as.machine.retryInGuest(func() {
- prev = as.pageTables.Unmap(addr, uintptr(length)) || prev
- })
- if prev {
+ // See above & bluepill_allocator.go.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ if prev := as.unmapLocked(addr, length); prev {
+ // Invalidate all active vCPUs.
as.invalidate()
// Recycle any freed intermediate pages.
@@ -227,7 +242,7 @@ func (as *addressSpace) Release() {
as.Unmap(0, ^uint64(0))
// Free all pages from the allocator.
- as.pageTables.Allocator.(allocator).base.Drain()
+ as.pageTables.Allocator.(*allocator).base.Drain()
// Drop all cached machine references.
as.machine.dropPageTables(as.pageTables)
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go
index 3f35414bb..9485e1301 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/bluepill_allocator.go
@@ -21,56 +21,80 @@ import (
)
type allocator struct {
- base *pagetables.RuntimeAllocator
+ base pagetables.RuntimeAllocator
+
+ // cpu must be set prior to any pagetable operation.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not be
+ // trapped, leading to the shadow pages being out of sync. Therefore,
+ // we need to ensure that we are in guest mode for page table
+ // modifications. See the call to bluepill, below.
+ cpu *vCPU
}
// newAllocator is used to define the allocator.
-func newAllocator() allocator {
- return allocator{
- base: pagetables.NewRuntimeAllocator(),
- }
+func newAllocator() *allocator {
+ a := new(allocator)
+ a.base.Init()
+ return a
}
// NewPTEs implements pagetables.Allocator.NewPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) NewPTEs() *pagetables.PTEs {
- return a.base.NewPTEs()
+func (a *allocator) NewPTEs() *pagetables.PTEs {
+ ptes := a.base.NewPTEs() // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+ return ptes
}
// PhysicalFor returns the physical address for a set of PTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
virtual := a.base.PhysicalFor(ptes)
physical, _, ok := translateToPhysical(virtual)
if !ok {
- panic(fmt.Sprintf("PhysicalFor failed for %p", ptes))
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic.
}
return physical
}
// LookupPTEs implements pagetables.Allocator.LookupPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
- panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic.
}
return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
}
// FreePTEs implements pagetables.Allocator.FreePTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) FreePTEs(ptes *pagetables.PTEs) {
- a.base.FreePTEs(ptes)
+func (a *allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes) // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
}
// Recycle implements pagetables.Allocator.Recycle.
//
//go:nosplit
-func (a allocator) Recycle() {
+func (a *allocator) Recycle() {
a.base.Recycle()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 133c2203d..ddc1554d5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -63,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -72,13 +74,15 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
@@ -89,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
// bluepillArchExit is called during bluepillEnter.
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 99cac665d..0d1e83e6c 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -53,3 +53,10 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
context.Rbx = uint64(uintptr(unsafe.Pointer(c)))
context.Rip = uint64(dieTrampolineAddr)
}
+
+// getHypercallID returns hypercall ID.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ return _KVM_HYPERCALL_MAX
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index c215d443c..83643c602 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -66,6 +66,8 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -88,6 +90,8 @@ func (c *vCPU) KernelSyscall() {
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 4ca2b7717..abd36f973 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -61,3 +61,16 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
func bluepillArchFpContext(context unsafe.Pointer) *arch.FpsimdContext {
return &((*arch.SignalContext64)(context).Fpsimd64)
}
+
+// getHypercallID returns hypercall ID.
+//
+// On Arm64, the MMIO address should be 64-bit aligned.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ if addr < arm64HypercallMMIOBase || addr >= (arm64HypercallMMIOBase+_AARCH64_HYPERCALL_MMIO_SIZE) {
+ return _KVM_HYPERCALL_MAX
+ } else {
+ return int(((addr) - arm64HypercallMMIOBase) >> 3)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 2407014e9..a5b9be36d 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -58,12 +58,32 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
return &((*arch.UContext64)(context).MContext)
}
+// bluepillHandleHlt is reponsible for handling VM-Exit.
+//
+//go:nosplit
+func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+}
+
// bluepillHandler is called from the signal stub.
//
// The world may be stopped while this is executing, and it executes on the
// signal stack. It should only execute raw system calls and functions that are
// explicitly marked go:nosplit.
//
+// +checkescape:all
+//
//go:nosplit
func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
@@ -82,7 +102,8 @@ func bluepillHandler(context unsafe.Pointer) {
}
for {
- switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno {
+ _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no.
+ switch errno {
case 0: // Expected case.
case syscall.EINTR:
// First, we process whatever pending signal
@@ -90,7 +111,7 @@ func bluepillHandler(context unsafe.Pointer) {
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
timeout := syscall.Timespec{}
- sig, _, errno := syscall.RawSyscall6(
+ sig, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
0, // siginfo.
@@ -125,7 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_NMI, 0); errno != 0 {
@@ -156,25 +177,20 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "debug")
return
case _KVM_EXIT_HLT:
- // Copy out registers.
- bluepillArchExit(c, bluepillArchContext(context))
-
- // Return to the vCPUReady state; notify any waiters.
- user := atomic.LoadUint32(&c.state) & vCPUUser
- switch atomic.SwapUint32(&c.state, user) {
- case user | vCPUGuest: // Expected case.
- case user | vCPUGuest | vCPUWaiter:
- c.notify()
- default:
- throw("invalid state")
- }
+ bluepillGuestExit(c, context)
return
case _KVM_EXIT_MMIO:
+ physical := uintptr(c.runData.data[0])
+ if getHypercallID(physical) == _KVM_HYPERCALL_VMEXIT {
+ bluepillGuestExit(c, context)
+ return
+ }
+
// Increment the fault count.
atomic.AddUint32(&c.faults, 1)
// For MMIO, the physical address is the first data item.
- physical := uintptr(c.runData.data[0])
+ physical = uintptr(c.runData.data[0])
virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
if !ok {
c.die(bluepillArchContext(context), "invalid physical address")
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 29d457a7e..3134a076b 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -26,6 +26,9 @@ type kvmOneReg struct {
addr uint64
}
+// arm64HypercallMMIOBase is MMIO base address used to dispatch hypercalls.
+var arm64HypercallMMIOBase uintptr
+
const KVM_NR_SPSR = 5
type userFpsimdState struct {
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 1d5c77ff4..6a7676468 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -71,3 +71,13 @@ const (
_KVM_MEM_READONLY = uint32(1) << 1
_KVM_MEM_FLAGS_NONE = 0
)
+
+// KVM hypercall list.
+// Canonical list of hypercalls supported.
+const (
+ // On amd64, it uses 'HLT' to leave the guest.
+ // Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest.
+ // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ _KVM_HYPERCALL_VMEXIT int = iota
+ _KVM_HYPERCALL_MAX
+)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 531ae8b1e..6f0539c29 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -131,3 +131,10 @@ const (
_ESR_SEGV_PEMERR_L2 = 0xe
_ESR_SEGV_PEMERR_L3 = 0xf
)
+
+// Arm64: MMIO base address used to dispatch hypercalls.
+const (
+ // on Arm64, the MMIO address must be 64-bit aligned.
+ // Currently, we only need 1 hypercall: hypercall_vmexit.
+ _AARCH64_HYPERCALL_MMIO_SIZE = 1 << 3
+)
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index f1afc74dc..6c54712d1 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -52,16 +52,19 @@ type machine struct {
// available is notified when vCPUs are available.
available sync.Cond
- // vCPUs are the machine vCPUs.
+ // vCPUsByTID are the machine vCPUs.
//
// These are populated dynamically.
- vCPUs map[uint64]*vCPU
+ vCPUsByTID map[uint64]*vCPU
// vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
- vCPUsByID map[int]*vCPU
+ vCPUsByID []*vCPU
// maxVCPUs is the maximum number of vCPUs supported by the machine.
maxVCPUs int
+
+ // nextID is the next vCPU ID.
+ nextID uint32
}
const (
@@ -137,9 +140,8 @@ type dieState struct {
//
// Precondition: mu must be held.
func (m *machine) newVCPU() *vCPU {
- id := len(m.vCPUs)
-
// Create the vCPU.
+ id := int(atomic.AddUint32(&m.nextID, 1) - 1)
fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
if errno != 0 {
panic(fmt.Sprintf("error creating new vCPU: %v", errno))
@@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU {
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
- m := &machine{
- fd: vm,
- vCPUs: make(map[uint64]*vCPU),
- vCPUsByID: make(map[int]*vCPU),
- }
+ m := &machine{fd: vm}
m.available.L = &m.mu
m.kernel.Init(ring0.KernelOpts{
PageTables: pagetables.New(newAllocator()),
@@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) {
}
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+ // Create the vCPUs map/slices.
+ m.vCPUsByTID = make(map[uint64]*vCPU)
+ m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -274,6 +276,8 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
+//
+//go:nosplit
func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
_, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
@@ -304,7 +308,11 @@ func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
// Destroy vCPUs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
+ if c == nil {
+ continue
+ }
+
// Ensure the vCPU is not still running in guest mode. This is
// possible iff teardown has been done by other threads, and
// somehow a single thread has not executed any system calls.
@@ -337,7 +345,7 @@ func (m *machine) Get() *vCPU {
tid := procid.Current()
// Check for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.RUnlock()
return c
@@ -356,7 +364,7 @@ func (m *machine) Get() *vCPU {
tid = procid.Current()
// Recheck for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.Unlock()
return c
@@ -364,10 +372,10 @@ func (m *machine) Get() *vCPU {
for {
// Scan for an available vCPU.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -375,17 +383,17 @@ func (m *machine) Get() *vCPU {
}
// Create a new vCPU (maybe).
- if len(m.vCPUs) < m.maxVCPUs {
+ if int(m.nextID) < m.maxVCPUs {
c := m.newVCPU()
c.lock()
- m.vCPUs[tid] = c
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
}
// Scan for something not in user mode.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
continue
}
@@ -403,8 +411,8 @@ func (m *machine) Get() *vCPU {
}
// Steal the vCPU.
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -431,7 +439,7 @@ func (m *machine) Put(c *vCPU) {
// newDirtySet returns a new dirty set.
func (m *machine) newDirtySet() *dirtySet {
return &dirtySet{
- vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 923ce3909..acc823ba6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -51,9 +51,10 @@ func (m *machine) initArchState() error {
recover()
debug.SetPanicOnFault(old)
}()
- m.retryInGuest(func() {
- ring0.SetCPUIDFaulting(true)
- })
+ c := m.Get()
+ defer m.Put(c)
+ bluepill(c)
+ ring0.SetCPUIDFaulting(true)
return nil
}
@@ -89,8 +90,8 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
- if c.PCIDs != nil {
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
c.PCIDs.Drop(pt)
}
}
@@ -335,29 +336,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
}
}
-// retryInGuest runs the given function in guest mode.
-//
-// If the function does not complete in guest mode (due to execution of a
-// system call due to a GC stall, for example), then it will be retried. The
-// given function must be idempotent as a result of the retry mechanism.
-func (m *machine) retryInGuest(fn func()) {
- c := m.Get()
- defer m.Put(c)
- for {
- c.ClearErrorCode() // See below.
- bluepill(c) // Force guest mode.
- fn() // Execute the given function.
- _, user := c.ErrorCode()
- if user {
- // If user is set, then we haven't bailed back to host
- // mode via a kernel exception or system call. We
- // consider the full function to have executed in guest
- // mode and we can return.
- break
- }
- }
-}
-
// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
// There is no need to return read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 7156c245f..290f035dd 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -154,7 +154,7 @@ func (c *vCPU) setUserRegisters(uregs *userRegs) error {
//
//go:nosplit
func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_GET_REGS,
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 750751aa3..f3bf973de 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -106,7 +106,7 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
if c.PCIDs != nil {
c.PCIDs.Drop(pt)
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 3c02cef7c..48c834499 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -159,6 +159,10 @@ func (c *vCPU) initArchState() error {
return err
}
+ // Use the address of the exception vector table as
+ // the MMIO address base.
+ arm64HypercallMMIOBase = toLocation
+
data = ring0.PsrDefaultSet | ring0.KernelFlagsSet
reg.id = _KVM_ARM64_REGS_PSTATE
if err := c.setOneRegister(&reg); err != nil {
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index de7df4f80..9f86f6a7a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -115,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace {
//
//go:nosplit
func (c *vCPU) notify() {
- _, _, errno := syscall.RawSyscall6(
+ _, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_FUTEX,
uintptr(unsafe.Pointer(&c.state)),
linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index db6465663..2bc5f3ecd 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -362,9 +362,17 @@ mmio_exit:
MOVD R1, CPU_LAZY_VFP(RSV_REG)
VFP_DISABLE
- // MMIO_EXIT.
- MOVD $0, R9
- MOVD R0, 0xffff000000001000(R9)
+ // Trigger MMIO_EXIT/_KVM_HYPERCALL_VMEXIT.
+ //
+ // To keep it simple, I used the address of exception table as the
+ // MMIO base address, so that I can trigger a MMIO-EXIT by forcibly writing
+ // a read-only space.
+ // Also, the length is engough to match a sufficient number of hypercall ID.
+ // Then, in host user space, I can calculate this address to find out
+ // which hypercall.
+ MRS VBAR_EL1, R9
+ MOVD R0, 0x0(R9)
+
RET
// HaltAndResume halts execution and point the pointer to the resume function.
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 900c0bba7..021693791 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -31,23 +31,39 @@ type defaultHooks struct{}
// KernelSyscall implements Hooks.KernelSyscall.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelSyscall() { Halt() }
+func (defaultHooks) KernelSyscall() {
+ Halt()
+}
// KernelException implements Hooks.KernelException.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelException(Vector) { Halt() }
+func (defaultHooks) KernelException(Vector) {
+ Halt()
+}
// kernelSyscall is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() }
+func kernelSyscall(c *CPU) {
+ c.hooks.KernelSyscall()
+}
// kernelException is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) }
+func kernelException(c *CPU, vector Vector) {
+ c.hooks.KernelException(vector)
+}
// Init initializes a new CPU.
//
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 0feff8778..d37981dbf 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool {
//
// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
@@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
// Perform the switch.
swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
jumpToKernel() // Switch to upper half.
writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
@@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
jumpToUser() // Return to lower half.
- SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
index 23fd5c352..8d75b7599 100644
--- a/pkg/sentry/platform/ring0/pagetables/allocator.go
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -53,9 +53,14 @@ type RuntimeAllocator struct {
// NewRuntimeAllocator returns an allocator that uses runtime allocation.
func NewRuntimeAllocator() *RuntimeAllocator {
- return &RuntimeAllocator{
- used: make(map[*PTEs]struct{}),
- }
+ r := new(RuntimeAllocator)
+ r.Init()
+ return r
+}
+
+// Init initializes a RuntimeAllocator.
+func (r *RuntimeAllocator) Init() {
+ r.used = make(map[*PTEs]struct{})
}
// Recycle returns freed pages to the pool.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 87e88e97d..7f18ac296 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -86,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
if !opts.AccessType.Any() {
@@ -128,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
w := unmapWalker{
@@ -162,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
w := emptyWalker{
@@ -197,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false }
// Lookup returns the physical address for the given virtual address.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
mask := uintptr(usermem.PageSize - 1)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index e82d6cd1e..60c9896fc 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 677743113..027add1fd 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -35,6 +36,7 @@ import (
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
// We store metadata for hostinet sockets internally. Technically, we should
// access metadata (e.g. through stat, chmod) on the host for correctness,
@@ -59,6 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
fd: fd,
},
}
+ s.LockFD.Init(&lock.FileLocks{})
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 47ff48c00..66015e2bc 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
}
func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
- ipt := stk.IPTables()
- table, ok := ipt.Tables[tablename.String()]
+ table, ok := stk.IPTables().GetTable(tablename.String())
if !ok {
return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
-// FillDefaultIPTables sets stack's IPTables to the default tables and
-// populates them with metadata.
-func FillDefaultIPTables(stk *stack.Stack) {
- ipt := stack.DefaultTables()
-
- // In order to fill in the metadata, we have to translate ipt from its
- // netstack format to Linux's giant-binary-blob format.
- for name, table := range ipt.Tables {
- _, metadata, err := convertNetstackToBinary(name, table)
- if err != nil {
- panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+// FillIPTablesMetadata populates stack's IPTables with metadata.
+func FillIPTablesMetadata(stk *stack.Stack) {
+ stk.IPTables().ModifyTables(func(tables map[string]stack.Table) {
+ // In order to fill in the metadata, we have to translate ipt from its
+ // netstack format to Linux's giant-binary-blob format.
+ for name, table := range tables {
+ _, metadata, err := convertNetstackToBinary(name, table)
+ if err != nil {
+ panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ }
+ table.SetMetadata(metadata)
+ tables[name] = table
}
- table.SetMetadata(metadata)
- ipt.Tables[name] = table
- }
-
- stk.SetIPTables(ipt)
+ })
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
@@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
NumEntries: replace.NumEntries,
Size: replace.Size,
})
- ipt.Tables[replace.Name.String()] = table
- stk.SetIPTables(ipt)
+ stk.IPTables().ReplaceTable(replace.Name.String(), table)
return nil
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 3863293c7..1b4e0ad79 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
// Support only for OUTPUT chain.
// TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 57a1e1c12..4f98ee2d5 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
return false, false
}
- // Now we need the transport header. However, this may not have been set
- // yet.
- // TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the stack.Check codepath as matchers are
- // added.
- var tcpHeader header.TCP
- if pkt.TransportHeader != nil {
- tcpHeader = header.TCP(pkt.TransportHeader)
- } else {
- var length int
- if hook == stack.Prerouting {
- // The network header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // There's no valid TCP header here, so we hotdrop the
- // packet.
- return false, true
- }
- h := header.IPv4(hdr)
- pkt.NetworkHeader = hdr
- length = int(h.HeaderLength())
- }
- // The TCP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- // There's no valid TCP header here, so we hotdrop the
- // packet.
- return false, true
- }
- tcpHeader = header.TCP(hdr[length:])
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if len(tcpHeader) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we drop the packet immediately.
+ return false, true
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index cfa9e621d..3f20fc891 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
@@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
return false, false
}
- // Now we need the transport header. However, this may not have been set
- // yet.
- // TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the stack.Check codepath as matchers are
- // added.
- var udpHeader header.UDP
- if pkt.TransportHeader != nil {
- udpHeader = header.UDP(pkt.TransportHeader)
- } else {
- var length int
- if hook == stack.Prerouting {
- // The network header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // There's no valid UDP header here, so we hotdrop the
- // packet.
- return false, true
- }
- h := header.IPv4(hdr)
- pkt.NetworkHeader = hdr
- length = int(h.HeaderLength())
- }
- // The UDP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- // There's no valid UDP header here, so we hotdrop the
- // packet.
- return false, true
- }
- udpHeader = header.UDP(hdr[length:])
+ udpHeader := header.UDP(pkt.TransportHeader)
+ if len(udpHeader) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we drop the packet immediately.
+ return false, true
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 7212d8644..420e573c9 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index b854bf990..8bfee5193 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +41,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -66,7 +68,7 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
return nil, err
}
- return &SocketVFS2{
+ fd := &SocketVFS2{
socketOpsCommon: socketOpsCommon{
ports: t.Kernel().NetlinkPorts(),
protocol: protocol,
@@ -75,7 +77,9 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
connection: connection,
sendBufferSize: defaultSendBufferSize,
},
- }, nil
+ }
+ fd.LockFD.Init(&lock.FileLocks{})
+ return fd, nil
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 333e0042e..0f592ecc3 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -37,6 +37,7 @@ go_library(
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
@@ -50,5 +51,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 60df51dae..738277391 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -33,6 +33,7 @@ import (
"syscall"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
@@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer s.EventUnregister(&e)
if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP unlike UDP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
return syserr.TranslateNetstackError(err)
}
@@ -1237,6 +1246,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1777,6 +1798,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
case linux.TCP_USER_TIMEOUT:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2106,30 +2138,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
switch name {
case linux.TCP_CONGESTION,
linux.TCP_CORK,
- linux.TCP_DEFER_ACCEPT,
linux.TCP_FASTOPEN,
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_KEEPCNT,
- linux.TCP_KEEPIDLE,
- linux.TCP_KEEPINTVL,
- linux.TCP_LINGER2,
- linux.TCP_MAXSEG,
linux.TCP_QUEUE_SEQ,
- linux.TCP_QUICKACK,
linux.TCP_REPAIR,
linux.TCP_REPAIR_QUEUE,
linux.TCP_REPAIR_WINDOW,
linux.TCP_SAVED_SYN,
linux.TCP_SAVE_SYN,
- linux.TCP_SYNCNT,
linux.TCP_THIN_DUPACK,
linux.TCP_THIN_LINEAR_TIMEOUTS,
linux.TCP_TIMESTAMP,
- linux.TCP_ULP,
- linux.TCP_USER_TIMEOUT,
- linux.TCP_WINDOW_CLAMP:
+ linux.TCP_ULP:
t.Kernel().EmitUnimplementedEvent(t)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index fcd8013c0..1412a4810 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,6 +39,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -64,6 +66,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
protocol: protocol,
},
}
+ s.LockFD.Init(&lock.FileLocks{})
vfsfd := &s.vfsfd
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index f5fa18136..9b44c2b89 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (stack.IPTables, error) {
+func (s *Stack) IPTables() (*stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
-// FillDefaultIPTables sets the stack's iptables to the default tables, which
-// allow and do not modify all traffic.
-func (s *Stack) FillDefaultIPTables() {
- netfilter.FillDefaultIPTables(s.Stack)
+// FillIPTablesMetadata populates stack's IPTables with metadata.
+func (s *Stack) FillIPTablesMetadata() {
+ netfilter.FillIPTablesMetadata(s.Stack)
}
// Resume implements inet.Stack.Resume.
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index de2cc4bdf..7d4cc80fe 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index ce5b94ee7..09c6d3b27 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -252,7 +252,7 @@ func (e *connectionedEndpoint) Close() {
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
if ce.Type() != e.stype {
- return syserr.ErrConnectionRefused
+ return syserr.ErrWrongProtocolForSocket
}
// Check if ce is e to avoid a deadlock.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 5b29e9d7f..4bb2b6ff4 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -417,7 +417,18 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer ep.Release()
// Connect the server endpoint.
- return s.ep.Connect(t, ep)
+ err = s.ep.Connect(t, ep)
+
+ if err == syserr.ErrWrongProtocolForSocket {
+ // Linux for abstract sockets returns ErrConnectionRefused
+ // instead of ErrWrongProtocolForSocket.
+ path, _ := extractPath(sockaddr)
+ if len(path) > 0 && path[0] == 0 {
+ err = syserr.ErrConnectionRefused
+ }
+ }
+
+ return err
}
// Write implements fs.FileOperations.Write.
@@ -448,15 +459,25 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
To: nil,
}
if len(to) > 0 {
- ep, err := extractEndpoint(t, to)
- if err != nil {
- return 0, err
- }
- defer ep.Release()
- w.To = ep
+ switch s.stype {
+ case linux.SOCK_SEQPACKET:
+ to = nil
+ case linux.SOCK_STREAM:
+ if s.State() == linux.SS_CONNECTED {
+ return 0, syserr.ErrAlreadyConnected
+ }
+ return 0, syserr.ErrNotSupported
+ default:
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
- if ep.Passcred() && w.Control.Credentials == nil {
- w.Control.Credentials = control.MakeCreds(t)
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
}
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 45e109361..8c32371a2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -39,6 +40,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -51,7 +53,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType)
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
- fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d)
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &lock.FileLocks{})
if err != nil {
return nil, syserr.FromError(err)
}
@@ -60,7 +62,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType)
// NewFileDescription creates and returns a socket file description
// corresponding to the given mount and dentry.
-func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) {
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *lock.FileLocks) (*vfs.FileDescription, error) {
// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
// SOCK_DGRAM and don't require CAP_NET_RAW.
if stype == linux.SOCK_RAW {
@@ -73,6 +75,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 35a98212a..8347617bd 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -998,9 +998,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
- // The lock uid is that of the Task's FDTable.
- lockUniqueID := lock.UniqueID(t.FDTable().ID())
-
// These locks don't block; execute the non-blocking operation using the inode's lock
// context directly.
switch flock.Type {
@@ -1010,12 +1007,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
@@ -1026,18 +1023,18 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
return 0, nil, nil
case linux.F_UNLCK:
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng)
return 0, nil, nil
default:
return 0, nil, syserror.EINVAL
@@ -2157,22 +2154,6 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
nonblocking := operation&linux.LOCK_NB != 0
operation &^= linux.LOCK_NB
- // flock(2):
- // Locks created by flock() are associated with an open file table entry. This means that
- // duplicate file descriptors (created by, for example, fork(2) or dup(2)) refer to the
- // same lock, and this lock may be modified or released using any of these descriptors. Furthermore,
- // the lock is released either by an explicit LOCK_UN operation on any of these duplicate
- // descriptors, or when all such descriptors have been closed.
- //
- // If a process uses open(2) (or similar) to obtain more than one descriptor for the same file,
- // these descriptors are treated independently by flock(). An attempt to lock the file using
- // one of these file descriptors may be denied by a lock that the calling process has already placed via
- // another descriptor.
- //
- // We use the File UniqueID as the lock UniqueID because it needs to reference the same lock across dup(2)
- // and fork(2).
- lockUniqueID := lock.UniqueID(file.UniqueID)
-
// A BSD style lock spans the entire file.
rng := lock.LockRange{
Start: 0,
@@ -2183,29 +2164,29 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.LOCK_EX:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_SH:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_UN:
- file.Dirent.Inode.LockCtx.BSD.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng)
default:
// flock(2): EINVAL operation is invalid.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 9c8b44f64..9f93f4354 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -14,8 +14,10 @@ go_library(
"getdents.go",
"inotify.go",
"ioctl.go",
+ "lock.go",
"memfd.go",
"mmap.go",
+ "mount.go",
"path.go",
"pipe.go",
"poll.go",
@@ -41,6 +43,7 @@ go_library(
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/eventfd",
"//pkg/sentry/fsimpl/pipefs",
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index ca0f7fd1e..6006758a5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -168,7 +168,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
err := tmpfs.AddSeals(file, args[2].Uint())
return 0, nil, err
default:
- // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ // TODO(gvisor.dev/issue/2920): Everything else is not yet supported.
return 0, nil, syserror.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
new file mode 100644
index 000000000..bf19028c4
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Flock implements linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ operation := args[1].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ var blocker lock.Blocker
+ if !nonblocking {
+ blocker = t
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_SH:
+ if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_UN:
+ if err := file.UnlockBSD(t); err != nil {
+ return 0, nil, err
+ }
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
new file mode 100644
index 000000000..adeaa39cc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ // For null-terminated strings related to mount(2), Linux copies in at most
+ // a page worth of data. See fs/namespace.c:copy_mount_string().
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, err := copyInPath(t, targetAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the current mount namespace's associated user
+ // namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODEV |
+ linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var opts vfs.MountOptions
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ opts.Flags.NoATime = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ opts.Flags.NoExec = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ opts.ReadOnly = true
+ }
+ opts.GetFilesystemOptions.Data = data
+
+ target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer target.Release()
+
+ return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ opts := vfs.UmountOptions{
+ Flags: uint32(flags),
+ }
+
+ return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index ef8358b8a..954c82f97 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -62,7 +62,7 @@ func Override() {
s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
s.Table[59] = syscalls.Supported("execve", Execve)
s.Table[72] = syscalls.Supported("fcntl", Fcntl)
- delete(s.Table, 73) // flock
+ s.Table[73] = syscalls.Supported("fcntl", Flock)
s.Table[74] = syscalls.Supported("fsync", Fsync)
s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
s.Table[76] = syscalls.Supported("truncate", Truncate)
@@ -90,8 +90,8 @@ func Override() {
s.Table[138] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[161] = syscalls.Supported("chroot", Chroot)
s.Table[162] = syscalls.Supported("sync", Sync)
- delete(s.Table, 165) // mount
- delete(s.Table, 166) // umount2
+ s.Table[165] = syscalls.Supported("mount", Mount)
+ s.Table[166] = syscalls.Supported("umount2", Umount2)
delete(s.Table, 187) // readahead
s.Table[188] = syscalls.Supported("setxattr", Setxattr)
s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
diff --git a/pkg/sentry/time/muldiv_arm64.s b/pkg/sentry/time/muldiv_arm64.s
index 5ad57a8a3..8afc62d53 100644
--- a/pkg/sentry/time/muldiv_arm64.s
+++ b/pkg/sentry/time/muldiv_arm64.s
@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "funcdata.h"
#include "textflag.h"
// Documentation is available in parameters.go.
//
// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
TEXT ·muldiv64(SB),NOSPLIT,$40-33
+ GO_ARGS
+ NO_LOCAL_POINTERS
MOVD value+0(FP), R0
MOVD multiplier+8(FP), R1
MOVD divisor+16(FP), R2
diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go
index e1b9084ac..0ce1257f6 100644
--- a/pkg/sentry/time/parameters_test.go
+++ b/pkg/sentry/time/parameters_test.go
@@ -484,3 +484,18 @@ func TestMulDivOverflow(t *testing.T) {
})
}
}
+
+func BenchmarkMuldiv64(b *testing.B) {
+ var v uint64 = math.MaxUint64
+ for i := uint64(1); i <= 1000000; i++ {
+ mult := uint64(1000000000)
+ div := i * mult
+ res, ok := muldiv64(v, mult, div)
+ if !ok {
+ b.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div)
+ }
+ if want := v / i; res != want {
+ b.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want)
+ }
+ }
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 774cc66cc..16d9f3a28 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -72,6 +72,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/uniqueid",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 9aa133bcb..66f3105bd 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -39,8 +39,8 @@ Mount references are held by:
- Mount: Each referenced Mount holds a reference on its parent, which is the
mount containing its mount point.
-- VirtualFilesystem: A reference is held on each Mount that has not been
- umounted.
+- VirtualFilesystem: A reference is held on each Mount that has been connected
+ to a mount point, but not yet umounted.
MountNamespace and FileDescription references are held by users of VFS. The
expectation is that each `kernel.Task` holds a reference on its corresponding
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 8297f964b..599c3131c 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -31,6 +31,7 @@ type EpollInstance struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
DentryMetadataFileDescriptionImpl
+ NoLockFD
// q holds waiters on this EpollInstance.
q waiter.Queue
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index bb294563d..97b9b18d7 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -73,6 +73,8 @@ type FileDescription struct {
// writable is analogous to Linux's FMODE_WRITE.
writable bool
+ usedLockBSD uint32
+
// impl is the FileDescriptionImpl associated with this Filesystem. impl is
// immutable. This should be the last field in FileDescription.
impl FileDescriptionImpl
@@ -175,6 +177,12 @@ func (fd *FileDescription) DecRef() {
}
ep.interestMu.Unlock()
}
+
+ // If BSD locks were used, release any lock that it may have acquired.
+ if atomic.LoadUint32(&fd.usedLockBSD) != 0 {
+ fd.impl.UnlockBSD(context.Background(), fd)
+ }
+
// Release implementation resources.
fd.impl.Release()
if fd.writable {
@@ -420,13 +428,9 @@ type FileDescriptionImpl interface {
Removexattr(ctx context.Context, name string) error
// LockBSD tries to acquire a BSD-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): BSD-style file locking
LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
- // LockBSD releases a BSD-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): BSD-style file locking
+ // UnlockBSD releases a BSD-style advisory file lock.
UnlockBSD(ctx context.Context, uid lock.UniqueID) error
// LockPOSIX tries to acquire a POSIX-style advisory file lock.
@@ -736,3 +740,14 @@ func (fd *FileDescription) InodeID() uint64 {
func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
return fd.Sync(ctx)
}
+
+// LockBSD tries to acquire a BSD-style advisory file lock.
+func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error {
+ atomic.StoreUint32(&fd.usedLockBSD, 1)
+ return fd.impl.LockBSD(ctx, fd, lockType, blocker)
+}
+
+// UnlockBSD releases a BSD-style advisory file lock.
+func (fd *FileDescription) UnlockBSD(ctx context.Context) error {
+ return fd.impl.UnlockBSD(ctx, fd)
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index f4c111926..af7213dfd 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -21,8 +21,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -153,26 +154,6 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string)
return syserror.ENOTSUP
}
-// LockBSD implements FileDescriptionImpl.LockBSD.
-func (FileDescriptionDefaultImpl) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
- return syserror.EBADF
-}
-
-// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
-func (FileDescriptionDefaultImpl) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
- return syserror.EBADF
-}
-
-// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
-func (FileDescriptionDefaultImpl) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
- return syserror.EBADF
-}
-
-// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
-func (FileDescriptionDefaultImpl) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
- return syserror.EBADF
-}
-
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
@@ -384,3 +365,60 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M
fd.IncRef()
return nil
}
+
+// LockFD may be used by most implementations of FileDescriptionImpl.Lock*
+// functions. Caller must call Init().
+type LockFD struct {
+ locks *lock.FileLocks
+}
+
+// Init initializes fd with FileLocks to use.
+func (fd *LockFD) Init(locks *lock.FileLocks) {
+ fd.locks = locks
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return fd.locks.LockBSD(uid, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ fd.locks.UnlockBSD(uid)
+ return nil
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+ return fd.locks.LockPOSIX(uid, t, rng, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error {
+ fd.locks.UnlockPOSIX(uid, rng)
+ return nil
+}
+
+// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
+// returning ENOLCK.
+type NoLockFD struct{}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return syserror.ENOLCK
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error {
+ return syserror.ENOLCK
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 3a75d4d62..5061f6ac9 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -33,6 +33,7 @@ import (
type fileDescription struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
+ NoLockFD
}
// genCount contains the number of times its DynamicBytesSource.Generate()
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 286510195..8882fa84a 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -43,7 +43,7 @@ type Dentry struct {
// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is
// either d2's parent or an ancestor of d2's parent.
func IsAncestorDentry(d, d2 *Dentry) bool {
- for {
+ for d2 != nil { // Stop at root, where d2.parent == nil.
if d2.parent == d {
return true
}
@@ -52,6 +52,7 @@ func IsAncestorDentry(d, d2 *Dentry) bool {
}
d2 = d2.parent
}
+ return false
}
// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d.
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 05a3051a4..7fa7d2d0c 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -57,6 +57,7 @@ type Inotify struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
DentryMetadataFileDescriptionImpl
+ NoLockFD
// Unique identifier for this inotify instance. We don't just reuse the
// inotify fd because fds can be duped. These should not be exposed to the
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index e4ac6524b..32f901bd8 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -55,6 +55,10 @@ type Mount struct {
// ID is the immutable mount ID.
ID uint64
+ // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers". Immutable.
+ Flags MountFlags
+
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
// key.parent and key.point if they are not nil.
@@ -81,10 +85,6 @@ type Mount struct {
// umounted is true. umounted is protected by VirtualFilesystem.mountMu.
umounted bool
- // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
- // for MS_RDONLY which is tracked in "writers".
- flags MountFlags
-
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
@@ -95,10 +95,10 @@ type Mount struct {
func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
mnt := &Mount{
ID: atomic.AddUint64(&vfs.lastMountID, 1),
+ Flags: opts.Flags,
vfs: vfs,
fs: fs,
root: root,
- flags: opts.Flags,
ns: mntns,
refs: 1,
}
@@ -113,13 +113,12 @@ func (mnt *Mount) Options() MountOptions {
mnt.vfs.mountMu.Lock()
defer mnt.vfs.mountMu.Unlock()
return MountOptions{
- Flags: mnt.flags,
+ Flags: mnt.Flags,
ReadOnly: mnt.readOnly(),
}
}
-// A MountNamespace is a collection of Mounts.
-//
+// A MountNamespace is a collection of Mounts.//
// MountNamespaces are reference-counted. Unless otherwise specified, all
// MountNamespace methods require that a reference is held.
//
@@ -127,6 +126,9 @@ func (mnt *Mount) Options() MountOptions {
//
// +stateify savable
type MountNamespace struct {
+ // Owner is the usernamespace that owns this mount namespace.
+ Owner *auth.UserNamespace
+
// root is the MountNamespace's root mount. root is immutable.
root *Mount
@@ -163,6 +165,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return nil, err
}
mntns := &MountNamespace{
+ Owner: creds.UserNamespace,
refs: 1,
mountpoints: make(map[*Dentry]uint32),
}
@@ -265,8 +268,8 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
if err != nil {
return err
}
+ defer mnt.DecRef()
if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil {
- mnt.DecRef()
return err
}
return nil
@@ -279,6 +282,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
// MNT_FORCE is currently unimplemented except for the permission check.
+ // Force unmounting specifically requires CAP_SYS_ADMIN in the root user
+ // namespace, and not in the owner user namespace for the target mount. See
+ // fs/namespace.c:SYSCALL_DEFINE2(umount, ...)
if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
return syserror.EPERM
}
@@ -394,8 +400,15 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu
// references held by vd.
//
// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
-// writer critical section. d.mu must be locked. mnt.parent() == nil.
+// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt
+// must not already be connected.
func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) {
+ if checkInvariants {
+ if mnt.parent() != nil {
+ panic("VFS.connectLocked called on connected mount")
+ }
+ }
+ mnt.IncRef() // dropped by callers of umountRecursiveLocked
mnt.storeKey(vd)
if vd.mount.children == nil {
vd.mount.children = make(map[*Mount]struct{})
@@ -420,6 +433,11 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns
// writer critical section. mnt.parent() != nil.
func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
vd := mnt.loadKey()
+ if checkInvariants {
+ if vd.mount != nil {
+ panic("VFS.disconnectLocked called on disconnected mount")
+ }
+ }
mnt.storeKey(VirtualDentry{})
delete(vd.mount.children, mnt)
atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
@@ -741,7 +759,10 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
@@ -826,11 +847,12 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
- // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is
- // set.
fmt.Fprintf(buf, "%s ", opts)
// (7) Optional fields: zero or more fields of the form "tag[:value]".
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 53d364c5c..f223aeda8 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -75,6 +75,10 @@ type MknodOptions struct {
type MountFlags struct {
// NoExec is equivalent to MS_NOEXEC.
NoExec bool
+
+ // NoATime is equivalent to MS_NOATIME and indicates that the
+ // filesystem should not update access time in-place.
+ NoATime bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 52643a7c5..9acca8bc7 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -405,7 +405,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
vfs.putResolvingPath(rp)
if opts.FileExec {
- if fd.Mount().flags.NoExec {
+ if fd.Mount().Flags.NoExec {
fd.DecRef()
return nil, syserror.EACCES
}
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 26ce1077f..748273366 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -77,7 +77,10 @@ var DefaultOpts = Opts{
// trigger it.
const descheduleThreshold = 1 * time.Second
-var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+var (
+ stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout")
+ stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+)
// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
var stackDumpSameTaskPeriod = time.Minute
@@ -220,6 +223,9 @@ func (w *Watchdog) waitForStart() {
// We are fine.
return
}
+
+ stuckStartup.Increment()
+
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
w.doAction(w.StartupTimeoutAction, false, &buf)
@@ -323,7 +329,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:")
+ buf.WriteString("Watchdog goroutine is stuck")
w.doAction(w.TaskTimeoutAction, false, &buf)
}
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index e57d45f2a..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -22,7 +22,6 @@ go_test(
size = "small",
srcs = ["gonet_test.go"],
library = ":gonet",
- tags = ["flaky"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 76839eb92..62ac932bb 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
+// More returns whether the more fragments flag is set.
+func (b IPv4) More() bool {
+ return b.Flags()&IPv4FlagMoreFragments != 0
+}
+
// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 2c4591409..3499d8399 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 {
return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
}
+// IsAtomic returns whether the fragment header indicates an atomic fragment. An
+// atomic fragment is a fragment that contains all the data required to
+// reassemble a full packet.
+func (b IPv6FragmentExtHdr) IsAtomic() bool {
+ return !b.More() && b.FragmentOffset() == 0
+}
+
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
//
// The IPv6 payload may contain IPv6 extension headers before any upper layer
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 5eb78b398..20b183da0 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -181,12 +181,12 @@ func (e *Endpoint) NumQueued() int {
}
// InjectInbound injects an inbound packet.
-func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) {
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
@@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
route.Release()
p := PacketInfo{
- Pkt: &pkt,
+ Pkt: pkt,
Proto: protocol,
GSO: gso,
Route: route,
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 5ee508d48..f34082e1a 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -387,7 +387,7 @@ const (
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
@@ -641,7 +641,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
// InjectInbound injects an inbound packet.
-func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 6f41a71a8..eaee7e5d7 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,7 +45,7 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents stack.PacketBuffer
+ contents *stack.PacketBuffer
}
type context struct {
@@ -103,7 +103,7 @@ func (c *context) cleanup() {
}
}
-func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
c.ch <- packetInfo{remote, protocol, pkt}
}
@@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
Hash: hash,
@@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
Data: buffer.VectorisedView{},
}); err != nil {
@@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) {
want := packetInfo{
raddr: raddr,
proto: proto,
- contents: stack.PacketBuffer{
+ contents: &stack.PacketBuffer{
Data: buffer.View(b).ToVectorisedView(),
LinkHeader: buffer.View(hdr),
},
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index ca4229ed6..2dfd29aa9 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(remote, local, p, stack.PacketBuffer{
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{
Data: buffer.View(pkt).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 26c96a655..f04738cfb 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
LinkHeader: buffer.View(eth),
}
@@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
LinkHeader: buffer.View(eth),
}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 20d9e95f6..568c6874f 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
// Because we're immediately turning around and writing the packet back
// to the rx path, we intentionally don't preserve the remote and local
// link addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, stack.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
linkHeader := header.Ethernet(hdr)
vv.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{
Data: vv,
LinkHeader: buffer.View(linkHeader),
})
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index f0769830a..c69d6b7e9 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -80,7 +80,7 @@ func (m *InjectableEndpoint) IsAttached() bool {
}
// InjectInbound implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
@@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 87c734c1f..0744f66d6 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) {
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
@@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
hdr := buffer.NewPrependable(1)
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buffer.NewView(0).ToVectorisedView(),
})
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index ec5c5048a..b5dfb7850 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -102,7 +102,7 @@ func (q *queueDispatcher) dispatchLoop() {
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
-func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
newRoute := r.Clone()
@@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
- if !d.q.enqueue(&pkt) {
+ if !d.q.enqueue(pkt) {
return tcpip.ErrNoBufferSpace
}
d.newPacketWaker.Assert()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index f5dec0a7f..0374a2441 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// Add the ethernet header here.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
@@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b[:header.EthernetMinimumSize])
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{
+ d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{
Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index f3fc62607..28a2e88ba 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
@@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) {
randomFill(buf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
@@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) {
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
})
@@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: uu,
}); err != want {
@@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write the one-buffer packet again. It must succeed.
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b060d4627..f2e47b6a7 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -47,13 +47,6 @@ var LogPackets uint32 = 1
// LogPacketsToPCAP must be accessed atomically.
var LogPacketsToPCAP uint32 = 1
-var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{
- header.ICMPv4ProtocolNumber: header.IPv4MinimumSize,
- header.ICMPv6ProtocolNumber: header.IPv6MinimumSize,
- header.UDPProtocolNumber: header.UDPMinimumSize,
- header.TCPProtocolNumber: header.TCPMinimumSize,
-}
-
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
@@ -120,8 +113,8 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, &pkt)
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dumpPacket("recv", nil, protocol, pkt)
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -208,8 +201,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, &pkt)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.dumpPacket("send", gso, protocol, pkt)
return e.lower.WritePacket(r, gso, protocol, pkt)
}
@@ -287,7 +280,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
vv.TrimFront(header.ARPSize)
arp := header.ARP(hdr)
log.Infof(
- "%s arp %v (%v) -> %v (%v) valid:%v",
+ "%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
@@ -299,13 +292,6 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
return
}
- // We aren't guaranteed to have a transport header - it's possible for
- // writes via raw endpoints to contain only network headers.
- if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
- log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
- return
- }
-
// Figure out the transport layer info.
transName := "unknown"
srcPort := uint16(0)
@@ -346,7 +332,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -381,7 +367,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -428,7 +414,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
flagsStr[i] = ' '
}
}
- details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
@@ -437,7 +423,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
return
}
@@ -445,5 +431,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 617446ea2..6bc9033d0 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) {
remote = tcpip.LinkAddress(zeroMAC[:])
}
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.View(data).ToVectorisedView(),
}
if ethHdr != nil {
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index f5a05929f..949b3f2b2 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
@@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 0a9b99f18..63bf40562 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatchCount++
}
@@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
@@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) {
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 9d0797af7..7f27a840d 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -94,16 +94,12 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- v, ok := pkt.Data.PullUp(header.ARPSize)
- if !ok {
- return
- }
- h := header.ARP(v)
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.ARP(pkt.NetworkHeader)
if !h.IsValid() {
return
}
@@ -122,7 +118,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
fallthrough // also fill the cache from requests
@@ -177,7 +173,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
}
@@ -209,6 +205,17 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.ARPSize)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(header.ARPSize)
+ return 0, false, true
+}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 1646d9cde..66e67429c 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
Data: v.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index f42abc4bb..2982450f8 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
}
}
-// Process processes an incoming fragment belonging to an ID
-// and returns a complete packet when all the packets belonging to that ID have been received.
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet when all the packets belonging to that ID have been received.
func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4c20301c6..7c8fb3e0a 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
@@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
@@ -150,7 +150,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -289,9 +293,9 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -378,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: vv,
- })
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -444,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: frag1.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: frag2.ToVectorisedView(),
- })
+ pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -487,7 +488,11 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -530,9 +535,9 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -644,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view[:len(view)-c.trunc].ToVectorisedView(),
- })
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}
+
+// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
+// after truncation, is large enough to hold a network header, it makes part of
+// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
+// becomes Data.
+func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
+ v := view[:len(view)-trunc]
+ if len(v) < netHdrLen {
+ return &stack.PacketBuffer{Data: v.ToVectorisedView()}
+ }
+ return &stack.PacketBuffer{
+ NetworkHeader: v[:netHdrLen],
+ Data: v[netHdrLen:].ToVectorisedView(),
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 4cbefe5ab..1b67aa066 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -24,7 +24,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
@@ -56,9 +56,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
received.Invalid.Increment()
@@ -88,7 +91,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{
Data: pkt.Data.Clone(nil),
NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
})
@@ -102,7 +105,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: vv,
TransportHeader: buffer.View(pkt),
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 64046cbbf..7e9f16c90 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -21,6 +21,7 @@
package ipv4
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -129,7 +130,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
// packet's stated length matches the length of the header+payload. mtu
// includes the IP header and options. This does not support the DontFragment
// IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
@@ -169,7 +170,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
if i > 0 {
newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -188,7 +189,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayload := pkt.Data.Clone(nil)
newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -202,7 +203,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
startOfHdr := pkt.Header
startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: startOfHdr,
Data: emptyVV,
NetworkHeader: buffer.View(h),
@@ -245,7 +246,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -253,43 +254,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok {
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
return nil
}
+ // If the packet is manipulated as per NAT Ouput rules, handle packet
+ // based on destination address and do not send the packet to link layer.
+ // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
+ // only NATted packets, but removing this check short circuits broadcasts
+ // before they are sent out to other hosts.
if pkt.NatDone {
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
netHeader := header.IPv4(pkt.NetworkHeader)
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- // The inbound path expects the network header to still be in
- // the PacketBuffer's Data field.
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
-
+ e.HandlePacket(&loopedR, pkt)
loopedR.Release()
}
if r.Loop&stack.PacketOut == 0 {
@@ -342,23 +329,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader)
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
- if err == nil {
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ ep.HandlePacket(&route, pkt)
n++
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
}
@@ -370,7 +350,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
@@ -426,35 +406,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv4(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- h := header.IPv4(headerView)
- if !h.IsValid(pkt.Data.Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
- pkt.NetworkHeader = headerView[:h.HeaderLength()]
-
- hlen := int(h.HeaderLength())
- tlen := int(h.TotalLength())
- pkt.Data.TrimFront(hlen)
- pkt.Data.CapLength(tlen - hlen)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok {
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
return
}
- more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
- if more || h.FragmentOffset() != 0 {
- if pkt.Data.Size() == 0 {
+ if h.More() || h.FragmentOffset() != 0 {
+ if pkt.Data.Size()+len(pkt.TransportHeader) == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -473,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
var ready bool
var err error
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -485,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- headerView.CapLength(hlen)
+ pkt.NetworkHeader.CapLength(int(h.HeaderLength()))
e.handleICMP(r, pkt)
return
}
@@ -565,6 +533,35 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // If there are options, pull those into hdr as well.
+ if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(headerLen)
+ if !ok {
+ panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen))
+ }
+ ipHdr = header.IPv4(hdr)
+ }
+
+ // If this is a fragment, don't bother parsing the transport header.
+ parseTransportHeader := true
+ if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
+ parseTransportHeader = false
+ }
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return ipHdr.TransportProtocol(), parseTransportHeader, true
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 36035c820..11e579c4b 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
+func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
@@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn
type errorChannel struct {
*channel.Endpoint
- Ch chan stack.PacketBuffer
+ Ch chan *stack.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -184,7 +184,7 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan stack.PacketBuffer, size),
+ Ch: make(chan *stack.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
@@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
select {
case e.Ch <- pkt:
default:
@@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := stack.PacketBuffer{
+ source := &stack.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []stack.PacketBuffer
+ var results []*stack.PacketBuffer
L:
for {
select {
@@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) {
s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
})
}
@@ -644,6 +652,18 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload1, udpPayload2},
},
+ {
+ name: "Fragment without followup",
+ fragments: []fragmentData{
+ {
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1[:64],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -698,7 +718,7 @@ func TestReceiveFragments(t *testing.T) {
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
- e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{
Data: vv,
})
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index bdf3a0d25..2ff7eedf4 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -27,7 +27,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return
@@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
- iph := header.IPv6(netHeader)
+ iph := header.IPv6(pkt.NetworkHeader)
// Validate ICMPv6 checksum before processing the packet.
//
@@ -288,7 +291,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
}); err != nil {
sent.Dropped.Increment()
@@ -390,7 +393,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: pkt.Data,
}); err != nil {
@@ -532,7 +535,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d412ff688..52a01b44e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -67,7 +67,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) {
},
}
- handleIPv6Payload := func(hdr buffer.Prependable) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ ep.HandlePacket(&r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
}
for _, typ := range types {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- handleIPv6Payload(hdr)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -328,7 +324,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{
Data: vv,
})
}
@@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
handleIPv6Payload := func(checksum bool) {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
}
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(typ.size + extraDataLen),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
})
}
@@ -740,7 +733,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -918,7 +911,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index daf1fcbc6..95fbcf2d1 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, stack.PacketBuffer{
+ e.HandlePacket(&loopedR, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -163,30 +163,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv6(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- h := header.IPv6(headerView)
- if !h.IsValid(pkt.Data.Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
-
- pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
- pkt.Data.TrimFront(header.IPv6MinimumSize)
- pkt.Data.CapLength(int(h.PayloadLength()))
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ // vv consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader)
+ vv.Append(pkt.Data)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
for firstHeader := true; ; firstHeader = false {
@@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6FragmentExtHdr:
hasFragmentHeader = true
- fragmentOffset := extHdr.FragmentOffset()
- more := extHdr.More()
- if !more && fragmentOffset == 0 {
+ if extHdr.IsAtomic() {
// This fragment extension header indicates that this packet is an
// atomic fragment. An atomic fragment is a fragment that contains
// all the data required to reassemble a full packet. As per RFC 6946,
@@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
- rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+ rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
- if fragmentOffset == 0 {
+ if extHdr.FragmentOffset() == 0 {
// Check that the iterator ends with a raw payload as the first fragment
// should include all headers up to and including any upper layer
// headers, as per RFC 8200 section 4.5; only upper layer data
@@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
// The packet is a fragment, let's try to reassemble it.
- start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
last := start + uint16(fragmentPayloadLen) - 1
// Drop the packet if the fragmentOffset is incorrect. i.e the
@@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
var ready bool
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ // Note that pkt doesn't have its transport header set after reassembly,
+ // and won't until DeliverNetworkPacket sets it.
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6RawPayloadHeader:
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
+
+ // For unfragmented packets, extHdr still contains the transport header.
+ // Get rid of it.
+ //
+ // For reassembled fragments, pkt.TransportHeader is unset, so this is a
+ // no-op and pkt.Data begins with the transport header.
+ extHdr.Buf.TrimFront(len(pkt.TransportHeader))
pkt.Data = extHdr.Buf
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt, hasFragmentHeader)
+ e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
// TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
@@ -505,6 +510,79 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ //
+ // Parsing occurs again in HandlePacket because we don't track the
+ // extensions in PacketBuffer. Unfortunately, that means HandlePacket
+ // has to do the parsing work again.
+ var nextHdr tcpip.TransportProtocolNumber
+ foundNext := true
+ extensionsSize := 0
+traverseExtensions:
+ for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
+ if err != nil {
+ break
+ }
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ foundNext = false
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ // If this is an atomic fragment, we don't have to treat it specially.
+ if !extHdr.More() && extHdr.FragmentOffset() == 0 {
+ continue
+ }
+ // This is a non-atomic fragment and has to be re-assembled before we can
+ // examine the payload for a transport header.
+ foundNext = false
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader.
+ hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+
+ return nextHdr, foundNext, true
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 841a0cb7a..213ff64f2 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
DstAddr: addr2,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
vv := hdr.View().ToVectorisedView()
vv.Append(f.data)
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: vv,
})
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 12b70f7e9..64239ce9a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) {
return s, ep, r
}
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
if atomicFragment {
- bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
- bytes[0] = nextHdr
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
}
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions)))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: payload.ToVectorisedView(),
})
}
@@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetCode(test.code)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
@@ -884,7 +885,7 @@ func TestRouterAdvertValidation(t *testing.T) {
t.Fatalf("got rxRA = %d, want = 0", got)
}
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b937cb84b..edc29ad27 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -54,17 +54,27 @@ type Flags struct {
LoadBalanced bool
}
-func (f Flags) bits() reuseFlag {
- var rf reuseFlag
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
if f.MostRecent {
- rf |= mostRecentFlag
+ rf |= MostRecentFlag
}
if f.LoadBalanced {
- rf |= loadBalancedFlag
+ rf |= LoadBalancedFlag
}
return rf
}
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
+
// PortManager manages allocating, reserving and releasing ports.
type PortManager struct {
mu sync.RWMutex
@@ -78,56 +88,88 @@ type PortManager struct {
hint uint32
}
-type reuseFlag int
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
const (
- mostRecentFlag reuseFlag = 1 << iota
- loadBalancedFlag
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
nextFlag
- flagMask = nextFlag - 1
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
)
-type portNode struct {
- // refs stores the count for each possible flag combination.
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination, (0 though
+ // FlagMask).
refs [nextFlag]int
}
-func (p portNode) totalRefs() int {
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
var total int
- for _, r := range p.refs {
+ for _, r := range c.refs {
total += r
}
return total
}
-// flagRefs returns the number of references with all specified flags.
-func (p portNode) flagRefs(flags reuseFlag) int {
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
var total int
- for i, r := range p.refs {
- if reuseFlag(i)&flags == flags {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
total += r
}
}
return total
}
-// allRefsHave returns if all references have all specified flags.
-func (p portNode) allRefsHave(flags reuseFlag) bool {
- for i, r := range p.refs {
- if reuseFlag(i)&flags == flags && r > 0 {
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
return false
}
}
return true
}
-// intersectionRefs returns the set of flags shared by all references.
-func (p portNode) intersectionRefs() reuseFlag {
- intersection := flagMask
- for i, r := range p.refs {
+// IntersectionRefs returns the set of flags shared by all references.
+func (c FlagCounter) IntersectionRefs() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
if r > 0 {
- intersection &= reuseFlag(i)
+ intersection &= BitFlags(i)
}
}
return intersection
@@ -135,26 +177,26 @@ func (p portNode) intersectionRefs() reuseFlag {
// deviceNode is never empty. When it has no elements, it is removed from the
// map that references it.
-type deviceNode map[tcpip.NICID]portNode
+type deviceNode map[tcpip.NICID]FlagCounter
// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all portNodes. If binding to a specific device, check
+// device, check against all FlagCounters. If binding to a specific device, check
// against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
- flagBits := flags.bits()
+ flagBits := flags.Bits()
if bindToDevice == 0 {
// Trying to binding all devices.
if flagBits == 0 {
// Can't bind because the (addr,port) is already bound.
return false
}
- intersection := flagMask
+ intersection := FlagMask
for _, p := range d {
- i := p.intersectionRefs()
+ i := p.IntersectionRefs()
intersection &= i
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
@@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
return true
}
- intersection := flagMask
+ intersection := FlagMask
if p, ok := d[0]; ok {
- intersection = p.intersectionRefs()
+ intersection = p.IntersectionRefs()
if intersection&flagBits == 0 {
return false
}
}
if p, ok := d[bindToDevice]; ok {
- i := p.intersectionRefs()
+ i := p.IntersectionRefs()
intersection &= i
if intersection&flagBits == 0 {
return false
@@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
return false
}
- flagBits := flags.bits()
+ flagBits := flags.Bits()
// Reserve port on all network protocols.
for _, network := range networks {
@@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m[addr] = d
}
n := d[bindToDevice]
- n.refs[flagBits]++
+ n.AddRef(flagBits)
d[bindToDevice] = n
}
@@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
s.mu.Lock()
defer s.mu.Unlock()
- flagBits := flags.bits()
+ flagBits := flags.Bits()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
@@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
}
n.refs[flagBits]--
d[bindToDevice] = n
- if n.refs == [nextFlag]int{} {
+ if n.TotalRefs() == 0 {
delete(d, bindToDevice)
}
if len(d) == 0 {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f71073207..24f52b735 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -89,6 +89,7 @@ go_test(
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -110,5 +111,6 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7d1ede1f2..05bf62788 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
@@ -147,46 +146,8 @@ type ConnTrackTable struct {
Seed uint32
}
-// parseHeaders sets headers in the packet.
-func parseHeaders(pkt *PacketBuffer) {
- newPkt := pkt.Clone()
-
- // Set network header.
- hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- netHeader := header.IPv4(hdr)
- newPkt.NetworkHeader = hdr
- length := int(netHeader.HeaderLength())
-
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- // Set transport header.
- switch protocol := netHeader.TransportProtocol(); protocol {
- case header.UDPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
- }
- case header.TCPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
- }
- }
- pkt.NetworkHeader = newPkt.NetworkHeader
- pkt.TransportHeader = newPkt.TransportHeader
-}
-
// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
+func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
var tuple connTrackTuple
netHeader := header.IPv4(pkt.NetworkHeader)
@@ -257,15 +218,8 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
// transport protocols.
func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- if hook == Prerouting {
- // Headers will not be set in Prerouting.
- // TODO(gvisor.dev/issue/170): Change this after parsing headers
- // code is added.
- parseHeaders(pkt)
- }
-
var dir ctDirection
- tuple, err := packetToTuple(*pkt, hook)
+ tuple, err := packetToTuple(pkt, hook)
if err != nil {
return nil, dir
}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
index 6b64cd37f..3eff141e6 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/forwarder.go
@@ -32,7 +32,7 @@ type pendingPacket struct {
nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
- pkt PacketBuffer
+ pkt *PacketBuffer
}
type forwardQueue struct {
@@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue {
return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
shouldWait := false
f.Lock()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 344d60baa..a6546cef0 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -33,6 +33,10 @@ const (
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -68,16 +72,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
return &f.id
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
-
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -96,13 +93,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
@@ -112,7 +109,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = netHeader
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+ return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -190,7 +197,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
LocalLinkAddress tcpip.LinkAddress
- Pkt PacketBuffer
+ Pkt *PacketBuffer
}
type fwdTestLinkEndpoint struct {
@@ -203,12 +210,12 @@ type fwdTestLinkEndpoint struct {
}
// InjectInbound injects an inbound packet.
-func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
@@ -251,7 +258,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -270,7 +277,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw
func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, gso, protocol, *pkt)
+ e.WritePacket(r, gso, protocol, pkt)
n++
}
@@ -280,7 +287,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: PacketBuffer{Data: vv},
+ Pkt: &PacketBuffer{Data: vv},
}
select {
@@ -361,8 +368,8 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -398,8 +405,8 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -429,8 +436,8 @@ func TestForwardingWithNoResolver(t *testing.T) {
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -459,16 +466,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
- buf[0] = 4
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -509,8 +515,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -554,10 +559,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
+ if !ok {
+ t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
}
- // The first 5 packets should not be forwarded so the the
- // sequemnce number should start with 5.
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
@@ -609,8 +618,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
- buf[0] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Data.ToView()
- if b[0] < 8 {
- t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 709ede3fa..4e9b404c8 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -43,11 +43,11 @@ const HookUnset = -1
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() IPTables {
+func DefaultTables() *IPTables {
// TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
// iotas.
- return IPTables{
- Tables: map[string]Table{
+ return &IPTables{
+ tables: map[string]Table{
TablenameNat: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
@@ -106,7 +106,7 @@ func DefaultTables() IPTables {
UserChains: map[string]int{},
},
},
- Priorities: map[Hook][]string{
+ priorities: map[Hook][]string{
Input: []string{TablenameNat, TablenameFilter},
Prerouting: []string{TablenameMangle, TablenameNat},
Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
@@ -158,6 +158,36 @@ func EmptyNatTable() Table {
}
}
+// GetTable returns table by name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ t, ok := it.tables[name]
+ return t, ok
+}
+
+// ReplaceTable replaces or inserts table by name.
+func (it *IPTables) ReplaceTable(name string, table Table) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ it.tables[name] = table
+}
+
+// ModifyTables acquires write-lock and calls fn with internal name-to-table
+// map. This function can be used to update multiple tables atomically.
+func (it *IPTables) ModifyTables(fn func(map[string]Table)) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ fn(it.tables)
+}
+
+// GetPriorities returns slice of priorities associated with hook.
+func (it *IPTables) GetPriorities(hook Hook) []string {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.priorities[hook]
+}
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
it.connections.HandlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.Priorities[hook] {
- table := it.Tables[tablename]
+ for _, tablename := range it.GetPriorities(hook) {
+ table, _ := it.GetTable(tablename)
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -321,7 +351,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, *pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
return RuleDrop, 0
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 36cc6275d..92e31643e 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
- // Set network header.
- if hook == Prerouting {
- parseHeaders(pkt)
- }
-
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index a3bd3e700..4a6a5c6f1 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -16,6 +16,7 @@ package stack
import (
"strings"
+ "sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -78,13 +79,17 @@ const (
// IPTables holds all the tables for a netstack.
type IPTables struct {
- // Tables maps table names to tables. User tables have arbitrary names.
- Tables map[string]Table
+ // mu protects tables and priorities.
+ mu sync.RWMutex
- // Priorities maps each hook to a list of table names. The order of the
+ // tables maps table names to tables. User tables have arbitrary names. mu
+ // needs to be locked for accessing.
+ tables map[string]Table
+
+ // priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
- // hook.
- Priorities map[Hook][]string
+ // hook. mu needs to be locked for accessing.
+ priorities map[Hook][]string
connections ConnTrackTable
}
@@ -245,7 +250,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 526c7d6ff..e28c23d66 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -467,8 +467,17 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- // The timer used to send the next router solicitation message.
- rtrSolicitTimer *time.Timer
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer *time.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
timer = time.AfterFunc(0, func() {
- ndp.nic.mu.RLock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
if done {
// If we reach this point, it means that the DAD timer fired after
// another goroutine already obtained the NIC lock and stopped DAD
// before this function obtained the NIC lock. Simply return here and do
// nothing further.
- ndp.nic.mu.RUnlock()
return
}
@@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
dadDone := remaining == 0
- ndp.nic.mu.RUnlock()
var err *tcpip.Error
if !dadDone {
- err = ndp.sendDADPacket(addr)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
}
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
if done {
// If we reach this point, it means that DAD was stopped after we released
// the NIC's read lock and before we obtained the write lock.
@@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// addr.
//
// addr must be a tentative IPv6 address on ndp's NIC.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
+//
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- // Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
@@ -750,7 +775,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicitTimer != nil {
+ if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
return
}
@@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
if ref == nil {
- ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
}
+ ndp.nic.mu.Unlock()
+
localAddr := ref.ep.ID().LocalAddress
r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
@@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() {
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
@@ -1881,7 +1926,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
@@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
- if remaining == 0 {
- ndp.rtrSolicitTimer = nil
- } else if ndp.rtrSolicitTimer != nil {
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
// we still reached this point, then we know the NIC
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
- ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
+ ndp.nic.mu.Unlock()
})
}
@@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicitTimer == nil {
+ if ndp.rtrSolicit.timer == nil {
// Nothing to do.
return
}
- ndp.rtrSolicitTimer.Stop()
- ndp.rtrSolicitTimer = nil
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
}
// initializeTempAddrState initializes state related to temporary SLAAC
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index b3d174cdd..58f1ebf60 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
}
@@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 05646e5e2..644c0d437 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -457,8 +457,20 @@ type ipv6AddrCandidate struct {
// remoteAddr must be a valid IPv6 address.
func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
@@ -568,11 +580,6 @@ const (
// promiscuous indicates that the NIC's promiscuous flag should be observed
// when getting a NIC's referenced network endpoint.
promiscuous
-
- // forceSpoofing indicates that the NIC should be assumed to be spoofing,
- // regardless of what the NIC's spoofing flag is when getting a NIC's
- // referenced network endpoint.
- forceSpoofing
)
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
@@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
- id := NetworkEndpointID{address}
-
n.mu.RLock()
var spoofingOrPromiscuous bool
@@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
spoofingOrPromiscuous = n.mu.spoofing
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
- case forceSpoofing:
- spoofingOrPromiscuous = true
}
- if ref, ok := n.mu.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
switch ref.getKind() {
case permanentExpired:
@@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.mu.endpoints[id]; ok {
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
+
+/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
+/// and returns a temporary endpoint.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
- n.mu.Unlock()
return ref
}
// tryIncRef failing means the endpoint is scheduled to be removed once the
@@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
@@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
},
}, peb, temporary, static, false)
- n.mu.Unlock()
return ref
}
@@ -1153,7 +1161,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1167,7 +1175,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1212,12 +1220,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stack.stats.IP.PacketsReceived.Increment()
}
- netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
+ // The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(netHeader)
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1229,11 +1246,12 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
- if protocol == header.IPv4ProtocolNumber {
+ // Loopback traffic skips the prerouting chain.
+ if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok {
+ if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
return
}
@@ -1298,10 +1316,20 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
- pkt.Header = buffer.NewPrependable(linkHeaderLen)
+ // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
+ // by having lower layers explicity write each header instead of just
+ // pkt.Header.
+
+ // pkt may have set its NetworkHeader and TransportHeader. If we're
+ // forwarding, we'll have to copy them into pkt.Header.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
+ if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
+ }
+ if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
}
// WritePacket takes ownership of pkt, calculate numBytes first.
@@ -1318,7 +1346,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1332,13 +1360,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader == nil {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ pkt.TransportHeader = transHeader
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
+
+ if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@@ -1365,7 +1411,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index b01b3f476..31f865260 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,268 @@
package stack
import (
+ "math"
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ID implements NetworkEndpoint.ID.
+func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen implements NetworkEndpoint.PrefixLen.
+func (e *testIPv6Endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
@@ -44,7 +301,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket("", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 926df4d7b..1b5da6017 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -24,6 +24,8 @@ import (
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
type PacketBuffer struct {
+ _ noCopy
+
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
PacketBufferEntry
@@ -82,7 +84,32 @@ type PacketBuffer struct {
// VectorisedView but does not deep copy the underlying bytes.
//
// Clone also does not deep copy any of its other fields.
-func (pk PacketBuffer) Clone() PacketBuffer {
- pk.Data = pk.Data.Clone(nil)
- return pk
+//
+// FIXME(b/153685824): Data gets copied but not other header references.
+func (pk *PacketBuffer) Clone() *PacketBuffer {
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ Header: pk.Header,
+ LinkHeader: pk.LinkHeader,
+ NetworkHeader: pk.NetworkHeader,
+ TransportHeader: pk.TransportHeader,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ }
}
+
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock() {}
+func (*noCopy) Unlock() {}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index db89234e8..5cbc946b6 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -67,12 +67,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +100,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +118,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +150,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -168,6 +168,11 @@ type TransportProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
+ // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
+ // MinimumPacketSize()
+ Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -180,7 +185,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +194,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -242,7 +247,7 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It takes ownership of pkt. pkt.TransportHeader must have already
// been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length. It takes ownership of pkts and
@@ -251,7 +256,7 @@ type NetworkEndpoint interface {
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address. It takes ownership of pkt.
- WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -266,7 +271,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -313,6 +318,14 @@ type NetworkProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
@@ -327,7 +340,7 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -389,7 +402,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
// given route. pkts must not be zero length. It takes ownership of pkts and
@@ -431,7 +444,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 3d0e5cc6e..d65f8049e 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
+//
+// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
@@ -148,12 +150,14 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
+//
+// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
@@ -199,7 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0ab4c3e19..a2190341c 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -52,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -424,12 +424,8 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
- // tablesMu protects iptables.
- tablesMu sync.RWMutex
-
- // tables are the iptables packet filtering and manipulation rules. The are
- // protected by tablesMu.`
- tables IPTables
+ // tables are the iptables packet filtering and manipulation rules.
+ tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -676,6 +672,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
@@ -778,7 +775,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1020,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
nic, ok := s.nics[id]
if !ok {
return tcpip.ErrUnknownNICID
@@ -1400,25 +1404,25 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
s.cleanupEndpoints[ep] = struct{}{}
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
@@ -1741,18 +1745,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() IPTables {
- s.tablesMu.RLock()
- t := s.tables
- s.tablesMu.RUnlock()
- return t
-}
-
-// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt IPTables) {
- s.tablesMu.Lock()
- s.tables = ipt
- s.tablesMu.Unlock()
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 1a2cf007c..ffef9bc2c 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -52,6 +52,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -90,30 +94,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fakeNetHeaderLen)
-
// Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -132,24 +134,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
+ pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
+ pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
+ pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ f.HandlePacket(r, pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -163,7 +160,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
@@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -292,8 +300,8 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
- buf[0] = 3
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -304,8 +312,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
- buf[0] = 1
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 1
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -316,8 +324,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
- buf[0] = 2
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 2
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -328,7 +336,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -340,7 +348,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -362,7 +370,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -420,7 +428,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
if promiscuous {
if err := s.SetPromiscuousMode(nicID, true); err != nil {
@@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -2263,7 +2271,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2344,8 +2352,8 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
- buf[0] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9a33ed375..118b449d5 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,7 +15,6 @@
package stack
import (
- "container/heap"
"fmt"
"math/rand"
@@ -23,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
)
type protocolIDs struct {
@@ -43,14 +43,14 @@ type transportEndpoints struct {
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
- if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
return
}
delete(eps.endpoints, id)
@@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
@@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
@@ -214,23 +214,22 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
demux: d,
netProto: netProto,
transProto: transProto,
- reuse: reusePort,
}
epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ return multiPortEp.singleRegisterEndpoint(t, flags)
}
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
-func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
- if multiPortEp.unregisterEndpoint(t) {
+ if multiPortEp.unregisterEndpoint(t, flags) {
delete(epsByNIC.endpoints, bindToDevice)
}
return len(epsByNIC.endpoints) == 0
@@ -251,7 +250,7 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -279,10 +278,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
return err
}
}
@@ -290,35 +289,6 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
-type transportEndpointHeap []TransportEndpoint
-
-var _ heap.Interface = (*transportEndpointHeap)(nil)
-
-func (h *transportEndpointHeap) Len() int {
- return len(*h)
-}
-
-func (h *transportEndpointHeap) Less(i, j int) bool {
- return (*h)[i].UniqueID() < (*h)[j].UniqueID()
-}
-
-func (h *transportEndpointHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *transportEndpointHeap) Push(x interface{}) {
- *h = append(*h, x.(TransportEndpoint))
-}
-
-func (h *transportEndpointHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- old[n-1] = nil
- *h = old[:n-1]
- return x
-}
-
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
@@ -334,9 +304,10 @@ type multiPortEndpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- endpoints transportEndpointHeap
- // reuse indicates if more than one endpoint is allowed.
- reuse bool
+ // endpoints stores the transport endpoints in the order in which they
+ // were bound. This is required for UDP SO_REUSEADDR.
+ endpoints []TransportEndpoint
+ flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,6 +333,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[0]
}
+ if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ return mpep.endpoints[len(mpep.endpoints)-1]
+ }
+
payload := []byte{
byte(id.LocalPort),
byte(id.LocalPort >> 8),
@@ -379,7 +354,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
@@ -401,40 +376,47 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
+ bits := flags.Bits()
+
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if !ep.reuse || !reusePort {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
return tcpip.ErrPortInUse
}
}
- heap.Push(&ep.endpoints, t)
+ ep.endpoints = append(ep.endpoints, t)
+ ep.flags.AddRef(bits)
return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
-func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
for i, endpoint := range ep.endpoints {
if endpoint == t {
- heap.Remove(&ep.endpoints, i)
+ copy(ep.endpoints[i:], ep.endpoints[i+1:])
+ ep.endpoints[len(ep.endpoints)-1] = nil
+ ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
+
+ ep.flags.DropRef(flags.Bits())
break
}
}
return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
- // TODO(eyalsoha): Why?
- reusePort = false
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
@@ -454,15 +436,20 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.endpoints[id] = epsByNIC
}
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep, bindToDevice)
+ eps.unregisterEndpoint(id, ep, flags, bindToDevice)
}
}
}
@@ -470,7 +457,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -520,7 +507,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -544,7 +531,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 2474a7db3..73dada928 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -127,7 +128,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -165,7 +166,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a611e44ab..7e8b84867 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -83,12 +84,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
+ hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -153,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -198,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false, /* reuse */
- 0, /* bindtoDevice */
+ ports.Flags{},
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -215,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -232,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -289,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -324,6 +326,17 @@ func (*fakeTransportProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeTransportProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
+ if !ok {
+ return false
+ }
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(fakeTransHeaderLen)
+ return true
+}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
@@ -369,7 +382,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -380,7 +393,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -391,7 +404,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -446,7 +459,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -457,7 +470,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -468,7 +481,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -623,7 +636,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: req.ToVectorisedView(),
})
@@ -642,11 +655,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- hdrs := p.Pkt.Data.ToView()
- if dst := hdrs[0]; dst != 3 {
+ if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := hdrs[1]; src != 1 {
+ if src := p.Pkt.NetworkHeader[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9ce625c17..7e5c79776 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b1d820372..8ce294002 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,7 +111,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -140,11 +141,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -450,7 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: data.ToVectorisedView(),
TransportHeader: buffer.View(icmpv4),
@@ -481,7 +477,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: dataVV,
TransportHeader: buffer.View(icmpv6),
@@ -511,6 +507,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
@@ -611,14 +608,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -743,7 +740,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -805,7 +802,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3c47692b2..74ef6541e 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -124,6 +124,16 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 23158173d..baf08eda6 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -132,11 +132,6 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (stack.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
@@ -298,7 +293,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee754a5a..a406d815e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !e.associated {
@@ -348,7 +343,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
switch e.NetProto {
case header.IPv4ProtocolNumber:
if !e.associated {
- if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
return 0, nil, err
@@ -357,7 +352,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: buffer.View(payloadBytes).ToVectorisedView(),
Owner: e.owner,
@@ -584,7 +579,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -632,8 +627,9 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
},
}
- networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- combinedVV := networkHeader.ToVectorisedView()
+ headers := append(buffer.View(nil), pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV := headers.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
packet.timestampNS = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f38eb6833..e26f01fae 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -86,10 +86,6 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = [
- "flaky",
- ],
deps = [
":tcp",
"//pkg/sync",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e6a23c978..ad197e8db 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -238,13 +239,14 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.mu.Lock()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil {
n.mu.Unlock()
n.Close()
return nil, err
}
n.isRegistered = true
+ n.registeredReusePort = n.reusePort
return n, nil
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index e4a06c9e1..7da93dcc4 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
Data: data,
Hash: tf.txHash,
Owner: owner,
}
- buildTCPHdr(r, tf, &pkt, gso)
+ buildTCPHdr(r, tf, pkt, gso)
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 6062ca916..047704c80 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -186,7 +186,7 @@ func (d *dispatcher) wait() {
}
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5ba972f1..6e4d607da 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,7 +63,8 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +74,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -430,6 +465,10 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // registeredReusePort is set if the current endpoint registration was
+ // done with SO_REUSEPORT enabled.
+ registeredReusePort bool
+
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -986,8 +1025,9 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice)
e.isRegistered = false
+ e.registeredReusePort = false
}
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
@@ -1051,8 +1091,9 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice)
e.isRegistered = false
+ e.registeredReusePort = false
}
if e.isPortReserved {
@@ -1172,11 +1213,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -2058,10 +2094,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice)
if err != nil {
return err
}
+ e.registeredReusePort = e.reusePort
} else {
// The endpoint doesn't have a local port yet, so try to get
// one. Make sure that it isn't one that will result in the same
@@ -2093,12 +2130,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
+ switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) {
case nil:
// Port picking successful. Save the details of
// the selected port.
e.ID = id
e.boundBindToDevice = e.bindToDevice
+ e.registeredReusePort = e.reusePort
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -2296,12 +2334,13 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil {
return err
}
e.isRegistered = true
e.setEndpointState(StateListen)
+ e.registeredReusePort = e.reusePort
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
@@ -2462,7 +2501,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -2481,7 +2520,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index fc43c11e2..cbb779666 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.EndpointState() {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- case StateError, StateClose:
+ case epState.closed():
for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
@@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
- if state.connected() || state == StateTimeWait {
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
// as the endpoint is still being loaded and the stack reference is not
// yet initialized.
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
@@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- state := e.origEndpointState
- switch state {
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
closed := e.closed
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
- if state == StateFinWait2 && closed {
+ if epState == StateFinWait2 && closed {
// If the endpoint has been closed then make sure we notify so
// that the FIN_WAIT2 timer is started after a restore.
e.notifyProtocolGoroutine(notifyClose)
}
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
+ case epState == StateClose:
if e.isPortReserved {
tcpip.AsyncLoading.Add(1)
go func() {
@@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
- case StateError:
+ case epState == StateError:
e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
-
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 704d01c64..070b634b4 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a2a7ddeb..73b8a6782 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,6 +21,7 @@
package tcp
import (
+ "fmt"
"runtime"
"strings"
"time"
@@ -206,7 +207,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -217,7 +218,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -490,6 +491,26 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
return &p.synRcvdCount
}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(offset)
+ if !ok {
+ panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset))
+ }
+ }
+
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ return true
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 074edded6..0280892a8 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -35,6 +35,7 @@ type segment struct {
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -60,13 +61,14 @@ type segment struct {
xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader)
s.rcvdTime = time.Now()
return s
}
@@ -146,12 +148,6 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h, ok := s.data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
- hdr := header.TCP(h)
-
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -162,16 +158,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(hdr.DataOffset())
- if offset < header.TCPMinimumSize {
- return false
- }
- hdrWithOpts, ok := s.data.PullUp(offset)
- if !ok {
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
return false
}
- s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -180,22 +172,19 @@ func (s *segment) parse() bool {
if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
s.csumValid = true
verifyChecksum = false
- s.data.TrimFront(offset)
}
if verifyChecksum {
- hdr = header.TCP(hdrWithOpts)
- s.csum = hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = hdr.CalculateChecksum(xsum)
- s.data.TrimFront(offset)
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
- s.ackNumber = seqnum.Value(hdr.AckNumber())
- s.flags = hdr.Flags()
- s.window = seqnum.Size(hdr.WindowSize())
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 06dc9b7d7..acacb42e4 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
@@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+ // With TCP_CORK, hold back until minimum of the available
+ // send space and MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
@@ -816,6 +833,25 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
panic("Netstack queues FIN segments without data.")
}
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ // If the entire segment cannot be accomodated in the receiver
+ // advertized window, skip splitting and sending of the segment.
+ // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered
+ // by a probe timer. On this condition, it defers the segment
+ // split and transmit to a short probe timer.
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split the
+ // segment right here if there are no pending segments.
+ // If there are pending segments, segment transmits are deferred
+ // to the retransmit timer handler.
+ if s.sndUna != s.sndNxt && !segEnd.LessThan(end) {
+ return false
+ }
+
if !seg.sequenceNumber.LessThan(end) {
return false
}
@@ -824,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+ // The segment size limit is computed as a function of sender congestion
+ // window and MSS. When sender congestion window is > 1, this limit can
+ // be larger than MSS. Ensure that the currently available send space
+ // is not greater than minimum of this limit and MSS.
if available > limit {
available = limit
}
+ if available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
if seg.data.Size() > available {
s.splitSeg(seg, available)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 7b1d72cf4..9721f6caf 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: s,
})
}
@@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) {
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: c.BuildSegment(payload, h),
})
}
@@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
}
@@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 647b2067a..df5efbf6a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,7 @@
package udp
import (
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -102,7 +103,7 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
- reusePort bool
+ portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
@@ -213,7 +214,7 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
@@ -247,11 +248,6 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -430,24 +426,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
route = &e.route
dstPort = e.dstPort
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may need to
- // change in Route.Resolve() call below.
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != StateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
}
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ return
}
} else {
// Reject destination address if it goes through a different
@@ -478,10 +483,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
route = &r
dstPort = dst.Port
+ resolve = route.Resolve
}
if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
+ if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -552,10 +558,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
case tcpip.ReusePortOption:
e.mu.Lock()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.mu.Unlock()
case tcpip.V6OnlyOption:
@@ -789,11 +798,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
case tcpip.ReuseAddressOption:
- return false, nil
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
case tcpip.ReusePortOption:
e.mu.RLock()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.mu.RUnlock()
return v, nil
@@ -921,7 +934,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: data,
TransportHeader: buffer.View(udp),
@@ -958,6 +975,11 @@ func (e *endpoint) Disconnect() *tcpip.Error {
id stack.TransportEndpointID
btd tcpip.NICID
)
+
+ // We change this value below and we need the old value to unregister
+ // the endpoint.
+ boundPortFlags := e.boundPortFlags
+
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -970,16 +992,17 @@ func (e *endpoint) Disconnect() *tcpip.Error {
return err
}
e.state = StateBound
+ boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice)
e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -1051,6 +1074,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
+ oldPortFlags := e.boundPortFlags
+
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
@@ -1058,7 +1083,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
@@ -1122,20 +1147,15 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- // FIXME(b/129164367): Support SO_REUSEADDR.
- MostRecent: false,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice)
if err != nil {
return id, e.bindToDevice, err
}
- e.boundPortFlags = flags
id.LocalPort = port
}
+ e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
e.boundPortFlags = ports.Flags{}
@@ -1269,18 +1289,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- pkt.Data.TrimFront(header.UDPMinimumSize)
-
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
@@ -1336,7 +1354,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
defer e.mu.RUnlock()
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a674ceb68..c67e0ba95 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt stack.PacketBuffer
+ pkt *stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 52af6de22..4218e7d03 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -66,15 +66,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
- // Get the header then trim it from the view.
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Malformed packet.
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- if int(header.UDP(h).Length()) > pkt.Data.Size() {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -121,7 +115,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -130,9 +124,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
// For example, a raw or packet socket may use what UDP
// considers an unreachable destination. Thus we deep copy pkt
// to prevent multiple ownership and SR errors.
- newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- payload := newNetHeader.ToVectorisedView()
- payload.Append(pkt.Data.ToView().ToVectorisedView())
+ newHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ newHeader = append(newHeader, pkt.TransportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -140,9 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
})
case header.IPv6AddressSize:
@@ -164,11 +160,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
+ payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader})
payload.Append(pkt.Data)
payload.CapLength(payloadLen)
@@ -177,9 +173,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
})
}
return true
@@ -201,6 +198,18 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
+ // Packet is too small
+ return false
+ }
+ pkt.TransportHeader = h
+ pkt.Data.TrimFront(header.UDPMinimumSize)
+ return true
+}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8acaa607a..313a3f117 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -440,10 +440,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
})
}
@@ -487,10 +485,8 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
})
}
@@ -1720,6 +1716,58 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
}
}
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ c.t.Helper()
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
// TestShutdownRead verifies endpoint read shutdown and error
// stats increment on packet receive.
func TestShutdownRead(t *testing.T) {
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
deleted file mode 100644
index 2dcba84ae..000000000
--- a/pkg/tmutex/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tmutex",
- srcs = ["tmutex.go"],
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "tmutex_test",
- size = "medium",
- srcs = ["tmutex_test.go"],
- library = ":tmutex",
- deps = ["//pkg/sync"],
-)
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
deleted file mode 100644
index c4685020d..000000000
--- a/pkg/tmutex/tmutex.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tmutex provides the implementation of a mutex that implements an
-// efficient TryLock function in addition to Lock and Unlock.
-package tmutex
-
-import (
- "sync/atomic"
-)
-
-// Mutex is a mutual exclusion primitive that implements TryLock in addition
-// to Lock and Unlock.
-type Mutex struct {
- v int32
- ch chan struct{}
-}
-
-// Init initializes the mutex.
-func (m *Mutex) Init() {
- m.v = 1
- m.ch = make(chan struct{}, 1)
-}
-
-// Lock acquires the mutex. If it is currently held by another goroutine, Lock
-// will wait until it has a chance to acquire it.
-func (m *Mutex) Lock() {
- // Uncontended case.
- if atomic.AddInt32(&m.v, -1) == 0 {
- return
- }
-
- for {
- // Try to acquire the mutex again, at the same time making sure
- // that m.v is negative, which indicates to the owner of the
- // lock that it is contended, which will force it to try to wake
- // someone up when it releases the mutex.
- if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
- return
- }
-
- // Wait for the mutex to be released before trying again.
- <-m.ch
- }
-}
-
-// TryLock attempts to acquire the mutex without blocking. If the mutex is
-// currently held by another goroutine, it fails to acquire it and returns
-// false.
-func (m *Mutex) TryLock() bool {
- v := atomic.LoadInt32(&m.v)
- if v <= 0 {
- return false
- }
- return atomic.CompareAndSwapInt32(&m.v, 1, 0)
-}
-
-// Unlock releases the mutex.
-func (m *Mutex) Unlock() {
- if atomic.SwapInt32(&m.v, 1) == 0 {
- // There were no pending waiters.
- return
- }
-
- // Wake some waiter up.
- select {
- case m.ch <- struct{}{}:
- default:
- }
-}
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
deleted file mode 100644
index 05540696a..000000000
--- a/pkg/tmutex/tmutex_test.go
+++ /dev/null
@@ -1,258 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tmutex
-
-import (
- "fmt"
- "runtime"
- "sync/atomic"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-func TestBasicLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- m.Lock()
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- ch <- struct{}{}
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-
- // Make sure we can lock and unlock again.
- m.Lock()
- m.Unlock()
-}
-
-func TestTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Try to lock. It should succeed.
- if !m.TryLock() {
- t.Fatalf("TryLock failed on unlocked mutex")
- }
-
- // Try to lock again, it should now fail.
- if m.TryLock() {
- t.Fatalf("TryLock succeeded on locked mutex")
- }
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-}
-
-func TestMutualExclusion(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Test mutual exclusion by running "gr" goroutines concurrently, and
- // have each one increment a counter "iters" times within the critical
- // section established by the mutex.
- //
- // If at the end the counter is not gr * iters, then we know that
- // goroutines ran concurrently within the critical section.
- //
- // If one of the goroutines doesn't complete, it's likely a bug that
- // causes to it to wait forever.
- const gr = 1000
- const iters = 100000
- v := 0
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(1)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- if v != gr*iters {
- t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
- }
-}
-
-func TestMutualExclusionWithTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Similar to the previous, with the addition of some goroutines that
- // only increment the count if TryLock succeeds.
- const gr = 1000
- const iters = 100000
- total := int64(gr * iters)
- var tryTotal int64
- v := int64(0)
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(2)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- go func() {
- local := int64(0)
- for j := 0; j < iters; j++ {
- if m.TryLock() {
- v++
- m.Unlock()
- local++
- }
- }
- atomic.AddInt64(&tryTotal, local)
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- t.Logf("tryTotal = %d", tryTotal)
- total += tryTotal
-
- if v != total {
- t.Fatalf("Bad count: got %v, want %v", v, total)
- }
-}
-
-// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
-// differences:
-//
-// - The number of goroutines is variable, with the maximum value depending on
-// GOMAXPROCS.
-//
-// - The number of iterations per benchmark is controlled by the benchmarking
-// framework.
-//
-// - Care is taken to ensure that all goroutines participating in the benchmark
-// have been created before the benchmark begins.
-func BenchmarkTmutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m Mutex
- m.Init()
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
-
-// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
-// a comparison point.
-func BenchmarkSyncMutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m sync.Mutex
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
diff --git a/runsc/boot/filter/extra_filters_msan.go b/runsc/boot/filter/extra_filters_msan.go
index 5e5a3c998..209e646a7 100644
--- a/runsc/boot/filter/extra_filters_msan.go
+++ b/runsc/boot/filter/extra_filters_msan.go
@@ -26,6 +26,8 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
Report("MSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
+ syscall.SYS_CLONE: {},
+ syscall.SYS_MMAP: {},
syscall.SYS_SCHED_GETAFFINITY: {},
syscall.SYS_SET_ROBUST_LIST: {},
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index b98a1eb50..e83584b82 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -293,11 +293,11 @@ func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter,
procArgs.MountNamespace = mns
// Resolve the executable path from working dir and environment.
- f, err := user.ResolveExecutablePath(ctx, procArgs.Credentials, procArgs.MountNamespace, procArgs.Envv, procArgs.WorkingDirectory, procArgs.Argv[0])
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
if err != nil {
- return fmt.Errorf("searching for executable %q, cwd: %q, envv: %q: %v", procArgs.Argv[0], procArgs.WorkingDirectory, procArgs.Envv, err)
+ return err
}
- procArgs.Filename = f
+ procArgs.Filename = resolved
return nil
}
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index f802bc9fb..002479612 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -1056,7 +1056,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in
return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
}
- s.FillDefaultIPTables()
+ s.FillIPTablesMetadata()
return &s, nil
}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 6c84f0794..8eeb43e79 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -96,11 +96,11 @@ func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounte
procArgs.MountNamespaceVFS2 = mns
// Resolve the executable path from working dir and environment.
- f, err := user.ResolveExecutablePathVFS2(ctx, procArgs.Credentials, procArgs.MountNamespaceVFS2, procArgs.Envv, procArgs.WorkingDirectory, procArgs.Argv[0])
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
if err != nil {
- return fmt.Errorf("searching for executable %q, cwd: %q, envv: %q: %v", procArgs.Argv[0], procArgs.WorkingDirectory, procArgs.Envv, err)
+ return err
}
- procArgs.Filename = f
+ procArgs.Filename = resolved
return nil
}
@@ -272,7 +272,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF
case "ro":
opts.ReadOnly = true
case "noatime":
- // TODO(gvisor.dev/issue/1193): Implement MS_NOATIME.
+ opts.Flags.NoATime = true
case "noexec":
opts.Flags.NoExec = true
default:
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
index c087e1a3c..7e34a284a 100644
--- a/runsc/cgroup/BUILD
+++ b/runsc/cgroup/BUILD
@@ -20,4 +20,8 @@ go_test(
srcs = ["cgroup_test.go"],
library = ":cgroup",
tags = ["local"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
)
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index ef01820ef..e5cc9d622 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -43,6 +43,7 @@ var controllers = map[string]config{
"blkio": config{ctrlr: &blockIO{}},
"cpu": config{ctrlr: &cpu{}},
"cpuset": config{ctrlr: &cpuSet{}},
+ "hugetlb": config{ctrlr: &hugeTLB{}, optional: true},
"memory": config{ctrlr: &memory{}},
"net_cls": config{ctrlr: &networkClass{}},
"net_prio": config{ctrlr: &networkPrio{}},
@@ -52,7 +53,6 @@ var controllers = map[string]config{
// irrelevant for a sandbox.
"devices": config{ctrlr: &noop{}},
"freezer": config{ctrlr: &noop{}},
- "hugetlb": config{ctrlr: &noop{}, optional: true},
"perf_event": config{ctrlr: &noop{}},
"rdma": config{ctrlr: &noop{}, optional: true},
"systemd": config{ctrlr: &noop{}},
@@ -125,7 +125,7 @@ func fillFromAncestor(path string) (string, error) {
return val, nil
}
- // File is not set, recurse to parent and then set here.
+ // File is not set, recurse to parent and then set here.
name := filepath.Base(path)
parent := filepath.Dir(filepath.Dir(path))
val, err = fillFromAncestor(filepath.Join(parent, name))
@@ -446,7 +446,13 @@ func (*cpu) set(spec *specs.LinuxResources, path string) error {
if err := setOptionalValueInt(path, "cpu.cfs_quota_us", spec.CPU.Quota); err != nil {
return err
}
- return setOptionalValueUint(path, "cpu.cfs_period_us", spec.CPU.Period)
+ if err := setOptionalValueUint(path, "cpu.cfs_period_us", spec.CPU.Period); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint(path, "cpu.rt_period_us", spec.CPU.RealtimePeriod); err != nil {
+ return err
+ }
+ return setOptionalValueInt(path, "cpu.rt_runtime_us", spec.CPU.RealtimeRuntime)
}
type cpuSet struct{}
@@ -487,13 +493,17 @@ func (*blockIO) set(spec *specs.LinuxResources, path string) error {
}
for _, dev := range spec.BlockIO.WeightDevice {
- val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, dev.Weight)
- if err := setValue(path, "blkio.weight_device", val); err != nil {
- return err
+ if dev.Weight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.Weight)
+ if err := setValue(path, "blkio.weight_device", val); err != nil {
+ return err
+ }
}
- val = fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, dev.LeafWeight)
- if err := setValue(path, "blkio.leaf_weight_device", val); err != nil {
- return err
+ if dev.LeafWeight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.LeafWeight)
+ if err := setValue(path, "blkio.leaf_weight_device", val); err != nil {
+ return err
+ }
}
}
if err := setThrottle(path, "blkio.throttle.read_bps_device", spec.BlockIO.ThrottleReadBpsDevice); err != nil {
@@ -545,9 +555,22 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error {
type pids struct{}
func (*pids) set(spec *specs.LinuxResources, path string) error {
- if spec.Pids == nil {
+ if spec.Pids == nil || spec.Pids.Limit <= 0 {
return nil
}
val := strconv.FormatInt(spec.Pids.Limit, 10)
return setValue(path, "pids.max", val)
}
+
+type hugeTLB struct{}
+
+func (*hugeTLB) set(spec *specs.LinuxResources, path string) error {
+ for _, limit := range spec.HugepageLimits {
+ name := fmt.Sprintf("hugetlb.%s.limit_in_bytes", limit.Pagesize)
+ val := strconv.FormatUint(limit.Limit, 10)
+ if err := setValue(path, name, val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 548c80e9a..4db5ee5c3 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -15,7 +15,14 @@
package cgroup
import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
"testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
func TestUninstallEnoent(t *testing.T) {
@@ -65,3 +72,578 @@ func TestCountCpuset(t *testing.T) {
})
}
}
+
+func uint16Ptr(v uint16) *uint16 {
+ return &v
+}
+
+func uint32Ptr(v uint32) *uint32 {
+ return &v
+}
+
+func int64Ptr(v int64) *int64 {
+ return &v
+}
+
+func uint64Ptr(v uint64) *uint64 {
+ return &v
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+func checkDir(t *testing.T, dir string, contents map[string]string) {
+ all, err := ioutil.ReadDir(dir)
+ if err != nil {
+ t.Fatalf("ReadDir(%q): %v", dir, err)
+ }
+ fileCount := 0
+ for _, file := range all {
+ if file.IsDir() {
+ // Only want to compare files.
+ continue
+ }
+ fileCount++
+
+ want, ok := contents[file.Name()]
+ if !ok {
+ t.Errorf("file not expected: %q", file.Name())
+ continue
+ }
+ gotBytes, err := ioutil.ReadFile(filepath.Join(dir, file.Name()))
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+ got := strings.TrimSuffix(string(gotBytes), "\n")
+ if got != want {
+ t.Errorf("wrong file content, file: %q, want: %q, got: %q", file.Name(), want, got)
+ }
+ }
+ if fileCount != len(contents) {
+ t.Errorf("file is missing, want: %v, got: %v", contents, all)
+ }
+}
+
+func makeLinuxWeightDevice(major, minor int64, weight, leafWeight *uint16) specs.LinuxWeightDevice {
+ rv := specs.LinuxWeightDevice{
+ Weight: weight,
+ LeafWeight: leafWeight,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func makeLinuxThrottleDevice(major, minor int64, rate uint64) specs.LinuxThrottleDevice {
+ rv := specs.LinuxThrottleDevice{
+ Rate: rate,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func TestBlockIO(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxBlockIO
+ wants map[string]string
+ }{
+ {
+ name: "simple",
+ spec: &specs.LinuxBlockIO{
+ Weight: uint16Ptr(1),
+ LeafWeight: uint16Ptr(2),
+ },
+ wants: map[string]string{
+ "blkio.weight": "1",
+ "blkio.leaf_weight": "2",
+ },
+ },
+ {
+ name: "weight_device",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, uint16Ptr(3), uint16Ptr(4)),
+ },
+ },
+ wants: map[string]string{
+ "blkio.weight_device": "1:2 3",
+ "blkio.leaf_weight_device": "1:2 4",
+ },
+ },
+ {
+ name: "weight_device_nil_values",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, nil, nil),
+ },
+ },
+ },
+ {
+ name: "throttle",
+ spec: &specs.LinuxBlockIO{
+ ThrottleReadBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(1, 2, 3),
+ },
+ ThrottleReadIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(4, 5, 6),
+ },
+ ThrottleWriteBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(7, 8, 9),
+ },
+ ThrottleWriteIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(10, 11, 12),
+ },
+ },
+ wants: map[string]string{
+ "blkio.throttle.read_bps_device": "1:2 3",
+ "blkio.throttle.read_iops_device": "4:5 6",
+ "blkio.throttle.write_bps_device": "7:8 9",
+ "blkio.throttle.write_iops_device": "10:11 12",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxBlockIO{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ BlockIO: tc.spec,
+ }
+ ctrlr := blockIO{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPU(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Shares: uint64Ptr(1),
+ Quota: int64Ptr(2),
+ Period: uint64Ptr(3),
+ RealtimeRuntime: int64Ptr(4),
+ RealtimePeriod: uint64Ptr(5),
+ },
+ wants: map[string]string{
+ "cpu.shares": "1",
+ "cpu.cfs_quota_us": "2",
+ "cpu.cfs_period_us": "3",
+ "cpu.rt_runtime_us": "4",
+ "cpu.rt_period_us": "5",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpu{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPUSet(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Cpus: "foo",
+ Mems: "bar",
+ },
+ wants: map[string]string{
+ "cpuset.cpus": "foo",
+ "cpuset.mems": "bar",
+ },
+ },
+ // Don't test nil values because they are copied from the parent.
+ // See TestCPUSetAncestor().
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+// TestCPUSetAncestor checks that, when not available, value is read from
+// parent directory.
+func TestCPUSetAncestor(t *testing.T) {
+ // Prepare master directory with cgroup files that will be propagated to
+ // children.
+ grandpa, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(grandpa)
+
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.cpus"), []byte("parent-cpus"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.mems"), []byte("parent-mems"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ }{
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // Create empty files in intermediate directory. They should be ignored
+ // when reading, and then populated from parent.
+ parent, err := ioutil.TempDir(grandpa, "parent")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(parent)
+ if _, err := os.Create(filepath.Join(parent, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(parent, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ // cgroup files mmust exist.
+ dir, err := ioutil.TempDir(parent, "child")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ want := map[string]string{
+ "cpuset.cpus": "parent-cpus",
+ "cpuset.mems": "parent-mems",
+ }
+ // Both path and dir must have been populated from grandpa.
+ checkDir(t, parent, want)
+ checkDir(t, dir, want)
+ })
+ }
+}
+
+func TestHugeTlb(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec []specs.LinuxHugepageLimit
+ wants map[string]string
+ }{
+ {
+ name: "single",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ },
+ },
+ {
+ name: "multiple",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ {
+ Pagesize: "2G",
+ Limit: 456,
+ },
+ {
+ Pagesize: "1P",
+ Limit: 789,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ "hugetlb.2G.limit_in_bytes": "456",
+ "hugetlb.1P.limit_in_bytes": "789",
+ },
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ HugepageLimits: tc.spec,
+ }
+ ctrlr := hugeTLB{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestMemory(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxMemory
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxMemory{
+ Limit: int64Ptr(1),
+ Reservation: int64Ptr(2),
+ Swap: int64Ptr(3),
+ Kernel: int64Ptr(4),
+ KernelTCP: int64Ptr(5),
+ Swappiness: uint64Ptr(6),
+ DisableOOMKiller: boolPtr(true),
+ },
+ wants: map[string]string{
+ "memory.limit_in_bytes": "1",
+ "memory.soft_limit_in_bytes": "2",
+ "memory.memsw.limit_in_bytes": "3",
+ "memory.kmem.limit_in_bytes": "4",
+ "memory.kmem.tcp.limit_in_bytes": "5",
+ "memory.swappiness": "6",
+ "memory.oom_control": "1",
+ },
+ },
+ {
+ // Disable OOM killer should only write when set to true.
+ name: "oomkiller",
+ spec: &specs.LinuxMemory{
+ DisableOOMKiller: boolPtr(false),
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxMemory{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Memory: tc.spec,
+ }
+ ctrlr := memory{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkClass(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ ClassID: uint32Ptr(1),
+ },
+ wants: map[string]string{
+ "net_cls.classid": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkClass{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkPriority(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ Priorities: []specs.LinuxInterfacePriority{
+ {
+ Name: "foo",
+ Priority: 1,
+ },
+ },
+ },
+ wants: map[string]string{
+ "net_prio.ifpriomap": "foo 1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkPrio{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestPids(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxPids
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxPids{Limit: 1},
+ wants: map[string]string{
+ "pids.max": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxPids{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Pids: tc.spec,
+ }
+ ctrlr := pids{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 10448a759..3966e2d21 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -306,7 +306,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
}
// Replace the current spec, with the clean spec with symlinks resolved.
- if err := setupMounts(spec.Mounts, root); err != nil {
+ if err := setupMounts(conf, spec.Mounts, root); err != nil {
Fatalf("error setting up FS: %v", err)
}
@@ -322,7 +322,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
}
// Check if root needs to be remounted as readonly.
- if spec.Root.Readonly {
+ if spec.Root.Readonly || conf.Overlay {
// If root is a mount point but not read-only, we can change mount options
// to make it read-only for extra safety.
log.Infof("Remounting root as readonly: %q", root)
@@ -346,7 +346,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
// setupMounts binds mount all mounts specified in the spec in their correct
// location inside root. It will resolve relative paths and symlinks. It also
// creates directories as needed.
-func setupMounts(mounts []specs.Mount, root string) error {
+func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error {
for _, m := range mounts {
if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
continue
@@ -358,6 +358,11 @@ func setupMounts(mounts []specs.Mount, root string) error {
}
flags := specutils.OptionsToFlags(m.Options) | syscall.MS_BIND
+ if conf.Overlay {
+ // Force mount read-only if writes are not going to be sent to it.
+ flags |= syscall.MS_RDONLY
+ }
+
log.Infof("Mounting src: %q, dst: %q, flags: %#x", m.Source, dst, flags)
if err := specutils.Mount(m.Source, dst, m.Type, flags); err != nil {
return fmt.Errorf("mounting %v: %v", m, err)
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index e7715b6f7..cd76645bd 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -20,6 +20,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "math"
"os"
"path"
"path/filepath"
@@ -53,9 +54,8 @@ func waitForProcessList(cont *Container, want []*control.Process) error {
err = fmt.Errorf("error getting process data from container: %v", err)
return &backoff.PermanentError{Err: err}
}
- if r, err := procListsEqual(got, want); !r {
- return fmt.Errorf("container got process list: %s, want: %s: error: %v",
- procListToString(got), procListToString(want), err)
+ if !procListsEqual(got, want) {
+ return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want))
}
return nil
}
@@ -92,36 +92,72 @@ func blockUntilWaitable(pid int) error {
return err
}
-// procListsEqual is used to check whether 2 Process lists are equal for all
-// implemented fields.
-func procListsEqual(got, want []*control.Process) (bool, error) {
- if len(got) != len(want) {
- return false, nil
- }
- for i := range got {
- pd1 := got[i]
- pd2 := want[i]
- // Zero out timing dependant fields.
- pd1.Time = ""
- pd1.STime = ""
- pd1.C = 0
- // Ignore TTY field too, since it's not relevant in the cases
- // where we use this method. Tests that care about the TTY
- // field should check for it themselves.
- pd1.TTY = ""
- pd1Json, err := control.ProcessListToJSON([]*control.Process{pd1})
- if err != nil {
- return false, err
+// procListsEqual is used to check whether 2 Process lists are equal. Fields
+// set to -1 in wants are ignored. Timestamp and threads fields are always
+// ignored.
+func procListsEqual(gots, wants []*control.Process) bool {
+ if len(gots) != len(wants) {
+ return false
+ }
+ for i := range gots {
+ got := gots[i]
+ want := wants[i]
+
+ if want.UID != math.MaxUint32 && want.UID != got.UID {
+ return false
}
- pd2Json, err := control.ProcessListToJSON([]*control.Process{pd2})
- if err != nil {
- return false, err
+ if want.PID != -1 && want.PID != got.PID {
+ return false
}
- if pd1Json != pd2Json {
- return false, nil
+ if want.PPID != -1 && want.PPID != got.PPID {
+ return false
+ }
+ if len(want.TTY) != 0 && want.TTY != got.TTY {
+ return false
+ }
+ if len(want.Cmd) != 0 && want.Cmd != got.Cmd {
+ return false
}
}
- return true, nil
+ return true
+}
+
+type processBuilder struct {
+ process control.Process
+}
+
+func newProcessBuilder() *processBuilder {
+ return &processBuilder{
+ process: control.Process{
+ UID: math.MaxUint32,
+ PID: -1,
+ PPID: -1,
+ },
+ }
+}
+
+func (p *processBuilder) Cmd(cmd string) *processBuilder {
+ p.process.Cmd = cmd
+ return p
+}
+
+func (p *processBuilder) PID(pid kernel.ThreadID) *processBuilder {
+ p.process.PID = pid
+ return p
+}
+
+func (p *processBuilder) PPID(ppid kernel.ThreadID) *processBuilder {
+ p.process.PPID = ppid
+ return p
+}
+
+func (p *processBuilder) UID(uid auth.KUID) *processBuilder {
+ p.process.UID = uid
+ return p
+}
+
+func (p *processBuilder) Process() *control.Process {
+ return &p.process
}
func procListToString(pl []*control.Process) string {
@@ -323,14 +359,7 @@ func TestLifecycle(t *testing.T) {
// expectedPL lists the expected process state of the container.
expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{1},
- },
+ newProcessBuilder().Cmd("sleep").Process(),
}
// Create the container.
args := Args{
@@ -608,10 +637,14 @@ func doAppExitStatus(t *testing.T, vfs2 bool) {
// TestExec verifies that a container can exec a new program.
func TestExec(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
- const uid = 343
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "exec-test")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir)
+ spec := testutil.NewSpecWithArgs("sh", "-c", cmd)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -634,29 +667,127 @@ func TestExec(t *testing.T) {
t.Fatalf("error starting container: %v", err)
}
- // expectedPL lists the expected process state of the container.
+ // Wait until sleep is running to ensure the symlink was created.
expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sh").Process(),
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("waitForProcessList: %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ args control.ExecArgs
+ }{
+ {
+ name: "complete",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ },
+ },
+ {
+ name: "argv",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename resolution",
+ args: control.ExecArgs{
+ Filename: "true",
+ Envv: []string{"PATH=/bin"},
+ },
+ },
+ {
+ name: "argv resolution",
+ args: control.ExecArgs{
+ Argv: []string{"true"},
+ Envv: []string{"PATH=/bin"},
+ },
+ },
{
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{1},
+ name: "argv symlink",
+ args: control.ExecArgs{
+ Argv: []string{filepath.Join(dir, "symlink")},
+ },
},
{
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{2},
+ name: "working dir",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${PWD}" != "/tmp" ]]; then exit 1; fi`},
+ WorkingDirectory: "/tmp",
+ },
},
+ {
+ name: "user",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -u)" != "343" ]]; then exit 1; fi`},
+ KUID: 343,
+ },
+ },
+ {
+ name: "group",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -g)" != "343" ]]; then exit 1; fi`},
+ KGID: 343,
+ },
+ },
+ {
+ name: "env",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${FOO}" != "123" ]]; then exit 1; fi`},
+ Envv: []string{"FOO=123"},
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // t.Parallel()
+ if ws, err := cont.executeSync(&tc.args); err != nil {
+ t.Fatalf("executeAsync(%+v): %v", tc.args, err)
+ } else if ws != 0 {
+ t.Fatalf("executeAsync(%+v) failed with exit: %v", tc.args, ws)
+ }
+ })
}
+ })
+ }
+}
- // Verify that "sleep 100" is running.
- if err := waitForProcessList(cont, expectedPL[:1]); err != nil {
- t.Error(err)
+// TestExecProcList verifies that a container can exec a new program and it
+// shows correcly in the process list.
+func TestExecProcList(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ const uid = 343
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
}
execArgs := &control.ExecArgs{
@@ -666,9 +797,8 @@ func TestExec(t *testing.T) {
KUID: uid,
}
- // Verify that "sleep 100" and "sleep 5" are running
- // after exec. First, start running exec (whick
- // blocks).
+ // Verify that "sleep 100" and "sleep 5" are running after exec. First,
+ // start running exec (which blocks).
ch := make(chan error)
go func() {
exitStatus, err := cont.executeSync(execArgs)
@@ -681,6 +811,11 @@ func TestExec(t *testing.T) {
}
}()
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").UID(0).Process(),
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").UID(uid).Process(),
+ }
if err := waitForProcessList(cont, expectedPL); err != nil {
t.Fatalf("error waiting for processes: %v", err)
}
@@ -1242,24 +1377,9 @@ func TestCapabilities(t *testing.T) {
// expectedPL lists the expected process state of the container.
expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{1},
- },
- {
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "exe",
- Threads: []kernel.ThreadID{2},
- },
+ newProcessBuilder().Cmd("sleep").Process(),
}
- if err := waitForProcessList(cont, expectedPL[:1]); err != nil {
+ if err := waitForProcessList(cont, expectedPL); err != nil {
t.Fatalf("Failed to wait for sleep to start, err: %v", err)
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 207206dd2..c2b54696c 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -149,13 +149,13 @@ func TestMultiContainerSanity(t *testing.T) {
// Check via ps that multiple processes are running.
expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
}
expectedPL = []*control.Process{
- {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -195,13 +195,13 @@ func TestMultiPIDNS(t *testing.T) {
// Check via ps that multiple processes are running.
expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
}
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -257,7 +257,7 @@ func TestMultiPIDNSPath(t *testing.T) {
// Check via ps that multiple processes are running.
expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -267,7 +267,7 @@ func TestMultiPIDNSPath(t *testing.T) {
}
expectedPL = []*control.Process{
- {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -300,7 +300,7 @@ func TestMultiContainerWait(t *testing.T) {
// Check via ps that multiple processes are running.
expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -345,7 +345,7 @@ func TestMultiContainerWait(t *testing.T) {
// After Wait returns, ensure that the root container is running and
// the child has finished.
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err)
@@ -377,7 +377,7 @@ func TestExecWait(t *testing.T) {
// Check via ps that process is running.
expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Fatalf("failed to wait for sleep to start: %v", err)
@@ -412,7 +412,7 @@ func TestExecWait(t *testing.T) {
// Wait for the exec'd process to exit.
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Fatalf("failed to wait for second container to stop: %v", err)
@@ -498,9 +498,8 @@ func TestMultiContainerSignal(t *testing.T) {
// Check via ps that container 1 process is running.
expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().Cmd("sleep").Process(),
}
-
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
}
@@ -512,7 +511,7 @@ func TestMultiContainerSignal(t *testing.T) {
// Make sure process 1 is still running.
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -626,8 +625,10 @@ func TestMultiContainerDestroy(t *testing.T) {
if err != nil {
t.Fatalf("error getting process data from sandbox: %v", err)
}
- expectedPL := []*control.Process{{PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}}}
- if r, err := procListsEqual(pss, expectedPL); !r {
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if !procListsEqual(pss, expectedPL) {
t.Errorf("container got process list: %s, want: %s: error: %v",
procListToString(pss), procListToString(expectedPL), err)
}
@@ -664,7 +665,7 @@ func TestMultiContainerProcesses(t *testing.T) {
// Check root's container process list doesn't include other containers.
expectedPL0 := []*control.Process{
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL0); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
@@ -672,8 +673,8 @@ func TestMultiContainerProcesses(t *testing.T) {
// Same for the other container.
expectedPL1 := []*control.Process{
- {PID: 2, Cmd: "sh", Threads: []kernel.ThreadID{2}},
- {PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}},
+ newProcessBuilder().PID(2).Cmd("sh").Process(),
+ newProcessBuilder().PID(3).PPID(2).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL1); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
@@ -687,7 +688,7 @@ func TestMultiContainerProcesses(t *testing.T) {
if _, err := containers[1].Execute(args); err != nil {
t.Fatalf("error exec'ing: %v", err)
}
- expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep", Threads: []kernel.ThreadID{4}})
+ expectedPL1 = append(expectedPL1, newProcessBuilder().PID(4).Cmd("sleep").Process())
if err := waitForProcessList(containers[1], expectedPL1); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
}
@@ -1505,7 +1506,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Ensure container is running
c := containers[2]
expectedPL := []*control.Process{
- {PID: 3, Cmd: "sleep", Threads: []kernel.ThreadID{3}},
+ newProcessBuilder().PID(3).Cmd("sleep").Process(),
}
if err := waitForProcessList(c, expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -1533,7 +1534,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
continue // container[2] has been killed.
}
pl := []*control.Process{
- {PID: kernel.ThreadID(i + 1), Cmd: "sleep", Threads: []kernel.ThreadID{kernel.ThreadID(i + 1)}},
+ newProcessBuilder().PID(kernel.ThreadID(i + 1)).Cmd("sleep").Process(),
}
if err := waitForProcessList(c, pl); err != nil {
t.Errorf("Container %q was affected by another container: %v", c.ID, err)
@@ -1553,7 +1554,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Wait until sandbox stops. waitForProcessList will loop until sandbox exits
// and RPC errors out.
impossiblePL := []*control.Process{
- {PID: 100, Cmd: "non-existent-process", Threads: []kernel.ThreadID{100}},
+ newProcessBuilder().Cmd("non-existent-process").Process(),
}
if err := waitForProcessList(c, impossiblePL); err == nil {
t.Fatalf("Sandbox was not killed after gofer death")
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
index 4fe1067d2..0d9a191b5 100755
--- a/scripts/common_build.sh
+++ b/scripts/common_build.sh
@@ -63,6 +63,10 @@ function run_as_root() {
bazel run --run_under="sudo" "${binary}" -- "$@"
}
+function query() {
+ QUERY_RESULT=$(bazel query "$@")
+}
+
function collect_logs() {
# Zip out everything into a convenient form.
if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh
index f0fc444c8..727503bce 100755
--- a/scripts/packetdrill_tests.sh
+++ b/scripts/packetdrill_tests.sh
@@ -19,4 +19,5 @@ source $(dirname $0)/common.sh
make load-packetdrill
install_runsc_for_test runsc-d
-test_runsc $(bazel query "attr(tags, manual, tests(//test/packetdrill/...))")
+query "attr(tags, manual, tests(//test/packetdrill/...))"
+test_runsc $QUERY_RESULT
diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh
index 17fc43f27..51c11f23f 100755
--- a/scripts/packetimpact_tests.sh
+++ b/scripts/packetimpact_tests.sh
@@ -19,4 +19,5 @@ source $(dirname $0)/common.sh
make load-packetimpact
install_runsc_for_test runsc-d
-test_runsc $(bazel query "attr(tags, packetimpact, tests(//test/packetimpact/...))")
+query "attr(tags, packetimpact, tests(//test/packetimpact/...))"
+test_runsc $QUERY_RESULT
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 9cbb2ed5b..91c956e10 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -239,7 +239,9 @@ func TestMemLimit(t *testing.T) {
d := dockerutil.MakeDocker(t)
defer d.CleanUp()
- allocMemory := 500 * 1024
+ // N.B. Because the size of the memory file may grow in large chunks,
+ // there is a minimum threshold of 1GB for the MemTotal figure.
+ allocMemory := 1024 * 1024
out, err := d.Run(dockerutil.RunOpts{
Image: "basic/alpine",
Memory: allocMemory, // In kB.
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 172ad9e16..38319a3b2 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -303,6 +303,10 @@ func TestNATRedirectRequiresProtocol(t *testing.T) {
singleTest(t, NATRedirectRequiresProtocol{})
}
+func TestNATLoopbackSkipsPrerouting(t *testing.T) {
+ singleTest(t, NATLoopbackSkipsPrerouting{})
+}
+
func TestInputSource(t *testing.T) {
singleTest(t, FilterInputSource{})
}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 0a10ce7fe..5e54a3963 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -39,6 +39,7 @@ func init() {
RegisterTestCase(NATOutDontRedirectIP{})
RegisterTestCase(NATOutRedirectInvert{})
RegisterTestCase(NATRedirectRequiresProtocol{})
+ RegisterTestCase(NATLoopbackSkipsPrerouting{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -326,32 +327,6 @@ func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error {
return nil
}
-// loopbackTests runs an iptables rule and ensures that packets sent to
-// dest:dropPort are received by localhost:acceptPort.
-func loopbackTest(dest net.IP, args ...string) error {
- if err := natTable(args...); err != nil {
- return err
- }
- sendCh := make(chan error)
- listenCh := make(chan error)
- go func() {
- sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
- }()
- go func() {
- listenCh <- listenUDP(acceptPort, sendloopDuration)
- }()
- select {
- case err := <-listenCh:
- if err != nil {
- return err
- }
- case <-time.After(sendloopDuration):
- return errors.New("timed out")
- }
- // sendCh will always take the full sendloop time.
- return <-sendCh
-}
-
// NATOutRedirectTCPPort tests that connections are redirected on specified ports.
type NATOutRedirectTCPPort struct{}
@@ -400,3 +375,65 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error {
func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error {
return nil
}
+
+// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't
+// affected by PREROUTING rules.
+type NATLoopbackSkipsPrerouting struct{}
+
+// Name implements TestCase.Name.
+func (NATLoopbackSkipsPrerouting) Name() string {
+ return "NATLoopbackSkipsPrerouting"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error {
+ // Redirect anything sent to localhost to an unused port.
+ dest := []byte{127, 0, 0, 1}
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+
+ // Establish a connection via localhost. If the PREROUTING rule did apply to
+ // loopback traffic, the connection would fail.
+ sendCh := make(chan error)
+ go func() {
+ sendCh <- connectTCP(dest, acceptPort, sendloopDuration)
+ }()
+
+ if err := listenTCP(acceptPort, sendloopDuration); err != nil {
+ return err
+ }
+ return <-sendCh
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// loopbackTests runs an iptables rule and ensures that packets sent to
+// dest:dropPort are received by localhost:acceptPort.
+func loopbackTest(dest net.IP, args ...string) error {
+ if err := natTable(args...); err != nil {
+ return err
+ }
+ sendCh := make(chan error)
+ listenCh := make(chan error)
+ go func() {
+ sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
+ }()
+ go func() {
+ listenCh <- listenUDP(acceptPort, sendloopDuration)
+ }()
+ select {
+ case err := <-listenCh:
+ if err != nil {
+ return err
+ }
+ case <-time.After(sendloopDuration):
+ return errors.New("timed out")
+ }
+ // sendCh will always take the full sendloop time.
+ return <-sendCh
+}
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index fb32964e9..8b4a4d905 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -43,16 +43,17 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
-func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
- fd, err = unix.Socket(domain, typ, 0)
+func pickPort(domain, typ int) (int, uint16, error) {
+ fd, err := unix.Socket(domain, typ, 0)
if err != nil {
- return -1, nil, err
+ return -1, 0, err
}
defer func() {
if err != nil {
err = multierr.Append(err, unix.Close(fd))
}
}()
+ var sa unix.Sockaddr
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
@@ -63,16 +64,20 @@ func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
sa = &sa6
default:
- return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
}
if err = unix.Bind(fd, sa); err != nil {
- return -1, nil, err
+ return -1, 0, err
}
sa, err = unix.Getsockname(fd)
if err != nil {
- return -1, nil, err
+ return -1, 0, err
}
- return fd, sa, nil
+ port, err := portFromSockaddr(sa)
+ if err != nil {
+ return -1, 0, err
+ }
+ return fd, port, nil
}
// layerState stores the state of a layer of a connection.
@@ -266,14 +271,10 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
}
// newTCPState creates a new TCPState.
-func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
- if err != nil {
- return nil, nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
+func newTCPState(domain int, out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
if err != nil {
- return nil, nil, err
+ return nil, err
}
s := tcpState{
out: TCP{SrcPort: &localPort},
@@ -283,12 +284,12 @@ func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) {
finSent: false,
}
if err := s.out.merge(&out); err != nil {
- return nil, nil, err
+ return nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, nil, err
+ return nil, err
}
- return &s, localAddr, nil
+ return &s, nil
}
func (s *tcpState) outgoing() Layer {
@@ -374,14 +375,10 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
-func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_DGRAM)
- if err != nil {
- return nil, nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
+func newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
- return nil, nil, err
+ return nil, err
}
s := udpState{
out: UDP{SrcPort: &localPort},
@@ -389,12 +386,12 @@ func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
portPickerFD: portPickerFD,
}
if err := s.out.merge(&out); err != nil {
- return nil, nil, err
+ return nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, nil, err
+ return nil, err
}
- return &s, localAddr, nil
+ return &s, nil
}
func (s *udpState) outgoing() Layer {
@@ -429,7 +426,6 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
- localAddr unix.Sockaddr
t *testing.T
}
@@ -475,20 +471,45 @@ func (conn *Connection) Close() {
}
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+// CreateFrame builds a frame for the connection with defaults overriden
+// from the innermost layer out, and additionalLayers added after it.
+//
+// Note that overrideLayers can have a length that is less than the number
+// of layers in this connection, and in such cases the innermost layers are
+// overriden first. As an example, valid values of overrideLayers for a TCP-
+// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
+// [Ethernet, IPv4, TCP].
+func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers {
var layersToSend Layers
- for _, s := range conn.layerStates {
- layersToSend = append(layersToSend, s.outgoing())
- }
- if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
+ for i, s := range conn.layerStates {
+ layer := s.outgoing()
+ // overrideLayers and conn.layerStates have their tails aligned, so
+ // to find the index we move backwards by the distance i is to the
+ // end.
+ if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
+ if err := layer.merge(overrideLayers[j]); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ }
+ }
+ layersToSend = append(layersToSend, layer)
}
layersToSend = append(layersToSend, additionalLayers...)
return layersToSend
}
+// SendFrameStateless sends a frame without updating any of the layer states.
+//
+// This method is useful for sending out-of-band control messages such as
+// ICMP packets, where it would not make sense to update the transport layer's
+// state using the ICMP header.
+func (conn *Connection) SendFrameStateless(frame Layers) {
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+}
+
// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *Connection) SendFrame(frame Layers) {
outBytes, err := frame.ToBytes()
@@ -509,10 +530,13 @@ func (conn *Connection) SendFrame(frame Layers) {
}
}
-// Send a packet with reasonable defaults. Potentially override the final layer
-// in the connection with the provided layer and add additionLayers.
-func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+// send sends a packet, possibly with layers of this connection overridden and
+// additional layers added.
+//
+// Types defined with Connection as the underlying type should expose
+// type-safe versions of this method.
+func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...))
}
// recvFrame gets the next successfully parsed frame (of type Layers) within the
@@ -606,7 +630,7 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- tcpState, localAddr, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
if err != nil {
t.Fatalf("can't make tcpState: %s", err)
}
@@ -623,20 +647,40 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
layerStates: []layerState{etherState, ipv4State, tcpState},
injector: injector,
sniffer: sniffer,
- localAddr: localAddr,
t: t,
}
}
-// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// Connect performs a TCP 3-way handshake. The input Connection should have a
// final TCP Layer.
-func (conn *TCPIPv4) Handshake() {
+func (conn *TCPIPv4) Connect() {
+ conn.t.Helper()
+
// Send the SYN.
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if synAck == nil {
+ if err != nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
+// The input Connection should have a final TCP Layer.
+func (conn *TCPIPv4) ConnectWithOptions(options []byte) {
+ conn.t.Helper()
+
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
@@ -656,10 +700,35 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio
return (*Connection)(conn).ExpectFrame(expected, timeout)
}
+// ExpectNextData attempts to receive the next incoming segment for the
+// connection and expects that to match the given layers.
+//
+// It differs from ExpectData() in that here we are only interested in the next
+// received segment, while ExpectData() can receive multiple segments for the
+// connection until there is a match with given layers or a timeout.
+func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ // Receive the first incoming TCP segment for this connection.
+ got, err := conn.ExpectData(&TCP{}, nil, timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length()))
+ }
+ if !(*Connection)(conn).match(expected, got) {
+ return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
+ }
+ return got, nil
+}
+
// Send a packet with reasonable defaults. Potentially override the TCP layer in
// the connection with the provided layer and add additionLayers.
func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&tcp, additionalLayers...)
+ (*Connection)(conn).send(Layers{&tcp}, additionalLayers...)
}
// Close frees associated resources held by the TCPIPv4 connection.
@@ -681,32 +750,48 @@ func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
return gotTCP, err
}
-func (conn *TCPIPv4) state() *tcpState {
- state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+func (conn *TCPIPv4) tcpState() *tcpState {
+ state, ok := conn.layerStates[2].(*tcpState)
if !ok {
- conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *TCPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
}
return state
}
// RemoteSeqNum returns the next expected sequence number from the DUT.
func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
- return conn.state().remoteSeqNum
+ return conn.tcpState().remoteSeqNum
}
// LocalSeqNum returns the next sequence number to send from the testbench.
func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
- return conn.state().localSeqNum
+ return conn.tcpState().localSeqNum
}
// SynAck returns the SynAck that was part of the handshake.
func (conn *TCPIPv4) SynAck() *TCP {
- return conn.state().synAck
+ return conn.tcpState().synAck
}
// LocalAddr gets the local socket address of this connection.
-func (conn *TCPIPv4) LocalAddr() unix.Sockaddr {
- return conn.localAddr
+func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
}
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
@@ -740,15 +825,10 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
}
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *IPv6Conn) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
-}
-
-// CreateFrame builds a frame for the connection with ipv6 overriding the ipv6
-// layer defaults and additionalLayers added after it.
-func (conn *IPv6Conn) CreateFrame(ipv6 IPv6, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(&ipv6, additionalLayers...)
+// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
+// additionalLayers added after it.
+func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...)
}
// Close to clean up any resources held.
@@ -762,12 +842,6 @@ func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers,
return (*Connection)(conn).ExpectFrame(frame, timeout)
}
-// Drain drains the sniffer's receive buffer by receiving packets until there's
-// nothing else to receive.
-func (conn *TCPIPv4) Drain() {
- conn.sniffer.Drain()
-}
-
// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
type UDPIPv4 Connection
@@ -781,7 +855,7 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- udpState, localAddr, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
@@ -798,42 +872,43 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
- localAddr: localAddr,
t: t,
}
}
-// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv4) LocalAddr() unix.Sockaddr {
- return conn.localAddr
+func (conn *UDPIPv4) udpState() *udpState {
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+func (conn *UDPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
}
-// Send a packet with reasonable defaults. Potentially override the UDP layer in
-// the connection with the provided layer and add additionLayers.
-func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&udp, additionalLayers...)
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *UDPIPv4) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&udp}, additionalLayers...)
}
-// SendIP sends a packet with additionalLayers following the IP layer in the
-// connection.
-func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
- var layersToSend Layers
- for _, s := range conn.layerStates[:len(conn.layerStates)-1] {
- layersToSend = append(layersToSend, s.outgoing())
- }
- layersToSend = append(layersToSend, additionalLayers...)
- conn.SendFrame(layersToSend)
+// SendIP sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv4 headers and adding additionLayers.
+func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 1b0e5b8fc..560c4111b 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -939,6 +939,11 @@ func (l *Payload) ToBytes() ([]byte, error) {
return l.Bytes, nil
}
+// Length returns payload byte length.
+func (l *Payload) Length() int {
+ return l.length()
+}
+
func (l *Payload) match(other Layer) bool {
return equalLayer(l, other)
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 1598c61e9..2a41ef326 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -124,10 +124,8 @@ packetimpact_go_test(
)
packetimpact_go_test(
- name = "tcp_should_piggyback",
- srcs = ["tcp_should_piggyback_test.go"],
- # TODO(b/153680566): Fix netstack then remove the line below.
- expect_netstack_failure = True,
+ name = "tcp_send_window_sizes_piggyback",
+ srcs = ["tcp_send_window_sizes_piggyback_test.go"],
deps = [
"//pkg/tcpip/header",
"//test/packetimpact/testbench",
@@ -202,6 +200,26 @@ packetimpact_go_test(
)
packetimpact_go_test(
+ name = "tcp_splitseg_mss",
+ srcs = ["tcp_splitseg_mss_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_cork_mss",
+ srcs = ["tcp_cork_mss_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
name = "icmpv6_param_problem",
srcs = ["icmpv6_param_problem_test.go"],
# TODO(b/153485026): Fix netstack then remove the line below.
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index c26ab78d9..407565078 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -21,11 +21,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestFinWait2Timeout(t *testing.T) {
@@ -37,13 +37,13 @@ func TestFinWait2Timeout(t *testing.T) {
{"WithoutLinger2", false},
} {
t.Run(tt.description, func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
if tt.linger2 {
@@ -52,21 +52,21 @@ func TestFinWait2Timeout(t *testing.T) {
}
dut.Close(acceptFd)
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
time.Sleep(5 * time.Second)
conn.Drain()
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
if tt.linger2 {
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected a RST packet within a second but got none: %s", err)
}
} else {
- if got, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
+ if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
t.Fatalf("expected no RST packets within ten seconds but got one: %s", got)
}
}
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
index bb1fc26fc..4d1d9a7f5 100644
--- a/test/packetimpact/tests/icmpv6_param_problem_test.go
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -21,32 +21,32 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT
// should respond with an ICMPv6 Parameter Problem message.
func TestICMPv6ParamProblemTest(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
- conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{})
+ conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
defer conn.Close()
- ipv6 := tb.IPv6{
+ ipv6 := testbench.IPv6{
// 254 is reserved and used for experimentation and testing. This should
// cause an error.
- NextHeader: tb.Uint8(254),
+ NextHeader: testbench.Uint8(254),
}
- icmpv6 := tb.ICMPv6{
- Type: tb.ICMPv6Type(header.ICMPv6EchoRequest),
+ icmpv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
NDPPayload: []byte("hello world"),
}
- toSend := conn.CreateFrame(ipv6, &icmpv6)
- conn.SendFrame(toSend)
+ toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6)
+ (*testbench.Connection)(&conn).SendFrame(toSend)
// Build the expected ICMPv6 payload, which includes an index to the
// problematic byte and also the problematic packet as described in
@@ -61,14 +61,14 @@ func TestICMPv6ParamProblemTest(t *testing.T) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset)
expectedPayload = append(b, expectedPayload...)
- expectedICMPv6 := tb.ICMPv6{
- Type: tb.ICMPv6Type(header.ICMPv6ParamProblem),
+ expectedICMPv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem),
NDPPayload: expectedPayload,
}
- paramProblem := tb.Layers{
- &tb.Ether{},
- &tb.IPv6{},
+ paramProblem := testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
&expectedICMPv6,
}
timeout := time.Second
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
index 49e481d0b..4efb9829c 100644
--- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -24,14 +24,14 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
-func recvTCPSegment(conn *tb.TCPIPv4, expect *tb.TCP, expectPayload *tb.Payload) (uint16, error) {
+func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
layers, err := conn.ExpectData(expect, expectPayload, time.Second)
if err != nil {
return 0, fmt.Errorf("failed to receive TCP segment: %s", err)
@@ -39,7 +39,7 @@ func recvTCPSegment(conn *tb.TCPIPv4, expect *tb.TCP, expectPayload *tb.Payload)
if len(layers) < 2 {
return 0, fmt.Errorf("got packet with layers: %v, expected to have at least 2 layers (link and network)", layers)
}
- ipv4, ok := layers[1].(*tb.IPv4)
+ ipv4, ok := layers[1].(*testbench.IPv4)
if !ok {
return 0, fmt.Errorf("got network layer: %T, expected: *IPv4", layers[1])
}
@@ -56,16 +56,16 @@ func recvTCPSegment(conn *tb.TCPIPv4, expect *tb.TCP, expectPayload *tb.Payload)
// to force the DF bit to be 0, and checks that a retransmitted segment has a
// different IPv4 Identification value than the original segment.
func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFD)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
remoteFD, _ := dut.Accept(listenFD)
defer dut.Close(remoteFD)
@@ -83,18 +83,18 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
}
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
dut.Send(remoteFD, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err)
}
// Let the DUT estimate RTO with RTT from the DATA-ACK.
// TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
// we can skip sending this ACK.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- expectTCP := &tb.TCP{SeqNum: tb.Uint32(uint32(*conn.RemoteSeqNum()))}
+ expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))}
dut.Send(remoteFD, sampleData, 0)
originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload)
if err != nil {
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
index 70a22a2db..6e7ff41d7 100644
--- a/test/packetimpact/tests/tcp_close_wait_ack_test.go
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -23,17 +23,17 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestCloseWaitAck(t *testing.T) {
for _, tt := range []struct {
description string
- makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) tb.TCP
+ makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP
seqNumOffset seqnum.Size
expectAck bool
}{
@@ -45,27 +45,27 @@ func TestCloseWaitAck(t *testing.T) {
{"ACK", GenerateUnaccACKSegment, 2, true},
} {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
// Send a FIN to DUT to intiate the active close
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
- gotTCP, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
}
windowSize := seqnum.Size(*gotTCP.WindowSize)
// Send a segment with OTW Seq / unacc ACK and expect an ACK back
- conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &tb.Payload{Bytes: []byte("Sample Data")})
- gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if tt.expectAck && err != nil {
t.Fatalf("expected an ack but got none: %s", err)
}
@@ -75,14 +75,14 @@ func TestCloseWaitAck(t *testing.T) {
// Now let's verify DUT is indeed in CLOSE_WAIT
dut.Close(acceptFd)
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
t.Fatalf("expected DUT to send a FIN: %s", err)
}
// Ack the FIN from DUT
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
// Send some extra data to DUT
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")})
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected DUT to send an RST: %s", err)
}
})
@@ -92,17 +92,17 @@ func TestCloseWaitAck(t *testing.T) {
// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the
// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK
// is expected from the receiver.
-func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) tb.TCP {
+func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
lastAcceptable := conn.LocalSeqNum().Add(windowSize)
otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
- return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)}
+ return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
}
// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated
// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is
// expected from the receiver.
-func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) tb.TCP {
+func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
lastAcceptable := conn.RemoteSeqNum()
unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
- return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)}
+ return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
}
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
new file mode 100644
index 000000000..fb8f48629
--- /dev/null
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_cork_mss_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPCorkMSS tests for segment coalesce and split as per MSS.
+func TestTCPCorkMSS(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ const mss = uint32(header.TCPDefaultMSS)
+ options := make([]byte, header.TCPOptionMSSLength)
+ header.EncodeMSSOption(mss, options)
+ conn.ConnectWithOptions(options)
+
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1)
+
+ // Let the dut application send 2 small segments to be held up and coalesced
+ // until the application sends a larger segment to fill up to > MSS.
+ sampleData := []byte("Sample Data")
+ dut.Send(acceptFD, sampleData, 0)
+ dut.Send(acceptFD, sampleData, 0)
+
+ expectedData := sampleData
+ expectedData = append(expectedData, sampleData...)
+ largeData := make([]byte, mss+1)
+ expectedData = append(expectedData, largeData...)
+ dut.Send(acceptFD, largeData, 0)
+
+ // Expect the segments to be coalesced and sent and capped to MSS.
+ expectedPayload := testbench.Payload{Bytes: expectedData[:mss]}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the coalesced segment to be split and transmitted.
+ expectedPayload = testbench.Payload{Bytes: expectedData[mss:]}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Check for segments to *not* be held up because of TCP_CORK when
+ // the current send window is less than MSS.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
+ dut.Send(acceptFD, sampleData, 0)
+ dut.Send(acceptFD, sampleData, 0)
+ expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)}
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index 2c1ec27d3..b9b3e91d3 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -21,22 +21,22 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestTcpNoAcceptCloseReset(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
- conn.Handshake()
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn.Connect()
defer conn.Close()
dut.Close(listenFd)
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
t.Fatalf("expected a RST-ACK packet but got none: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index 351df193e..ad8c74234 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -23,11 +23,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive
@@ -38,7 +38,7 @@ func TestTCPOutsideTheWindow(t *testing.T) {
for _, tt := range []struct {
description string
tcpFlags uint8
- payload []tb.Layer
+ payload []testbench.Layer
seqNumOffset seqnum.Size
expectACK bool
}{
@@ -46,28 +46,28 @@ func TestTCPOutsideTheWindow(t *testing.T) {
{"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
{"ACK", header.TCPFlagAck, nil, 0, false},
{"FIN", header.TCPFlagFin, nil, 0, false},
- {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 0, true},
{"SYN", header.TCPFlagSyn, nil, 1, true},
{"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
{"ACK", header.TCPFlagAck, nil, 1, true},
{"FIN", header.TCPFlagFin, nil, 1, false},
- {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 1, true},
{"SYN", header.TCPFlagSyn, nil, 2, true},
{"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
{"ACK", header.TCPFlagAck, nil, 2, true},
{"FIN", header.TCPFlagFin, nil, 2, false},
- {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 2, true},
} {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFD)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFD, _ := dut.Accept(listenFD)
defer dut.Close(acceptFD)
@@ -75,13 +75,13 @@ func TestTCPOutsideTheWindow(t *testing.T) {
conn.Drain()
// Ignore whatever incrementing that this out-of-order packet might cause
// to the AckNum.
- localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum()))
- conn.Send(tb.TCP{
- Flags: tb.Uint8(tt.tcpFlags),
- SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
+ localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum()))
+ conn.Send(testbench.TCP{
+ Flags: testbench.Uint8(tt.tcpFlags),
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
}, tt.payload...)
timeout := 3 * time.Second
- gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
if tt.expectACK && err != nil {
t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
}
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
index 0a668adcf..55db4ece6 100644
--- a/test/packetimpact/tests/tcp_paws_mechanism_test.go
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -22,25 +22,25 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestPAWSMechanism(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFD)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
options := make([]byte, header.TCPOptionTSLength)
header.EncodeTSOption(currentTS(), 0, options)
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn), Options: options})
- synAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
+ synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("didn't get synack during handshake: %s", err)
}
@@ -50,7 +50,7 @@ func TestPAWSMechanism(t *testing.T) {
}
tsecr := parsedSynOpts.TSVal
header.EncodeTSOption(currentTS(), tsecr, options)
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
acceptFD, _ := dut.Accept(listenFD)
defer dut.Close(acceptFD)
@@ -61,9 +61,9 @@ func TestPAWSMechanism(t *testing.T) {
// every time we send one, it should not cause any flakiness because timestamps
// only need to be non-decreasing.
time.Sleep(3 * time.Millisecond)
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options}, &tb.Payload{Bytes: sampleData})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK but got none: %s", err)
}
@@ -86,9 +86,9 @@ func TestPAWSMechanism(t *testing.T) {
// 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness
// due to the exact same reasoning discussed above.
time.Sleep(3 * time.Millisecond)
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options}, &tb.Payload{Bytes: sampleData})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err = conn.Expect(tb.TCP{AckNum: lastAckNum, Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err)
}
diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
index 5cc93fa24..b640d8673 100644
--- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
+++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
@@ -28,19 +28,19 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestQueueReceiveInSynSent(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
- socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
sampleData := []byte("Sample Data")
@@ -49,7 +49,7 @@ func TestQueueReceiveInSynSent(t *testing.T) {
if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
@@ -77,11 +77,11 @@ func TestQueueReceiveInSynSent(t *testing.T) {
time.Sleep(time.Second)
// Bring the connection to Established.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
- if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+ if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
// Send sample data to DUT.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: sampleData})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
index c043ad881..e51409b66 100644
--- a/test/packetimpact/tests/tcp_retransmits_test.go
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -21,53 +21,53 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestRetransmits tests retransmits occur at exponentially increasing
// time intervals.
func TestRetransmits(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
defer dut.Close(acceptFd)
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
// Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
// TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
// we can skip sending this ACK.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
startRTO := time.Second
current := startRTO
first := time.Now()
dut.Send(acceptFd, sampleData, 0)
- seq := tb.Uint32(uint32(*conn.RemoteSeqNum()))
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
+ seq := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
// Expect retransmits of the same segment.
for i := 0; i < 5; i++ {
start := time.Now()
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
t.Fatalf("expected a packet with payload %v: %s loop %d", samplePayload, err, i)
}
if i == 0 {
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
new file mode 100644
index 000000000..90ab85419
--- /dev/null
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_send_window_sizes_piggyback_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestSendWindowSizesPiggyback tests cases where segment sizes are close to
+// sender window size and checks for ACK piggybacking for each of those case.
+func TestSendWindowSizesPiggyback(t *testing.T) {
+ sampleData := []byte("Sample Data")
+ segmentSize := uint16(len(sampleData))
+ // Advertise receive window sizes that are lesser, equal to or greater than
+ // enqueued segment size and check for segment transmits. The test attempts
+ // to enqueue a segment on the dut before acknowledging previous segment and
+ // lets the dut piggyback any ACKs along with the enqueued segment.
+ for _, tt := range []struct {
+ description string
+ windowSize uint16
+ expectedPayload1 []byte
+ expectedPayload2 []byte
+ enqueue bool
+ }{
+ // Expect the first segment to be split as it cannot be accomodated in
+ // the sender window. This means we need not enqueue a new segment after
+ // the first segment.
+ {"WindowSmallerThanSegment", segmentSize - 1, sampleData[:(segmentSize - 1)], sampleData[(segmentSize - 1):], false /* enqueue */},
+
+ {"WindowEqualToSegment", segmentSize, sampleData, sampleData, true /* enqueue */},
+
+ // Expect the second segment to not be split as its size is greater than
+ // the available sender window size. The segments should not be split
+ // when there is pending unacknowledged data and the segment-size is
+ // greater than available sender window.
+ {"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Connect()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+
+ dut.Send(acceptFd, sampleData, 0)
+ expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Expect any enqueued segment to be transmitted by the dut along with
+ // piggybacked ACK for our data.
+
+ if tt.enqueue {
+ // Enqueue a segment for the dut to transmit.
+ dut.Send(acceptFd, sampleData, 0)
+ }
+
+ // Send ACK for the previous segment along with data for the dut to
+ // receive and ACK back. Sending this ACK would make room for the dut
+ // to transmit any enqueued segment.
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
+
+ // Expect the dut to piggyback the ACK for received data along with
+ // the segment enqueued for transmit.
+ expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go
deleted file mode 100644
index 0240dc2f9..000000000
--- a/test/packetimpact/tests/tcp_should_piggyback_test.go
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_should_piggyback_test
-
-import (
- "flag"
- "testing"
- "time"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
-)
-
-func init() {
- tb.RegisterFlags(flag.CommandLine)
-}
-
-func TestPiggyback(t *testing.T) {
- dut := tb.NewDUT(t)
- defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort})
- defer conn.Close()
-
- conn.Handshake()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
-
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
-
- sampleData := []byte("Sample Data")
-
- dut.Send(acceptFd, sampleData, 0)
- expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
- expectedPayload := tb.Payload{Bytes: sampleData}
- if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
- t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
- }
-
- // Cause DUT to send us more data as soon as we ACK their first data segment because we have
- // a small window.
- dut.Send(acceptFd, sampleData, 0)
-
- // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of
- // sending a separate ACK packet.
- conn.Send(expectedTCP, &expectedPayload)
- if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
- t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
- }
-}
diff --git a/test/packetimpact/tests/tcp_splitseg_mss_test.go b/test/packetimpact/tests/tcp_splitseg_mss_test.go
new file mode 100644
index 000000000..9350d0988
--- /dev/null
+++ b/test/packetimpact/tests/tcp_splitseg_mss_test.go
@@ -0,0 +1,71 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_splitseg_mss_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPSplitSegMSS lets the dut try to send segments larger than MSS.
+// It tests if the transmitted segments are capped at MSS and are split.
+func TestTCPSplitSegMSS(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ const mss = uint32(header.TCPDefaultMSS)
+ options := make([]byte, header.TCPOptionMSSLength)
+ header.EncodeMSSOption(mss, options)
+ conn.ConnectWithOptions(options)
+
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ // Let the dut send a segment larger than MSS.
+ largeData := make([]byte, mss+1)
+ for i := 0; i < 2; i++ {
+ dut.Send(acceptFD, largeData, 0)
+ if i == 0 {
+ // On Linux, the initial segment goes out beyond MSS and the segment
+ // split occurs on retransmission. Call ExpectData to wait to
+ // receive the split segment.
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: largeData[:mss]}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ } else {
+ if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: largeData[:mss]}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: largeData[mss:]}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ }
+}
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
index e6ba84cab..7d5deab01 100644
--- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -21,32 +21,32 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestTCPSynRcvdReset tests transition from SYN-RCVD to CLOSED.
func TestTCPSynRcvdReset(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFD)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
// Expect dut connection to have transitioned to SYN-RCVD state.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s", err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
index ce31917ee..87e45d765 100644
--- a/test/packetimpact/tests/tcp_user_timeout_test.go
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -22,27 +22,27 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
-func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error {
+func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
sampleData := make([]byte, 100)
for i := range sampleData {
sampleData[i] = uint8(i)
}
conn.Drain()
dut.Send(fd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &tb.Payload{Bytes: sampleData}, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
return fmt.Errorf("expected data but got none: %w", err)
}
return nil
}
-func sendFIN(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error {
+func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
dut.Close(fd)
return nil
}
@@ -59,20 +59,20 @@ func TestTCPUserTimeout(t *testing.T) {
} {
for _, ttf := range []struct {
description string
- f func(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error
+ f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error
}{
{"AfterPayload", sendPayload},
{"AfterFIN", sendFIN},
} {
t.Run(tt.description+ttf.description, func(t *testing.T) {
// Create a socket, listen, TCP handshake, and accept.
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFD)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFD, _ := dut.Accept(listenFD)
if tt.userTimeout != 0 {
@@ -85,14 +85,14 @@ func TestTCPUserTimeout(t *testing.T) {
time.Sleep(tt.sendDelay)
conn.Drain()
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
// If TCP_USER_TIMEOUT was set and the above delay was longer than the
// TCP_USER_TIMEOUT then the DUT should send a RST in response to the
// testbench's packet.
expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
expectTimeout := 5 * time.Second
- got, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, expectTimeout)
+ got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
if expectRST && err != nil {
t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
index 58ec1d740..576ec1a8b 100644
--- a/test/packetimpact/tests/tcp_window_shrink_test.go
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -21,53 +21,53 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestWindowShrink(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
defer dut.Close(acceptFd)
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
dut.Send(acceptFd, sampleData, 0)
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
// We close our receiving window here
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
dut.Send(acceptFd, []byte("Sample Data"), 0)
// Note: There is another kind of zero-window probing which Windows uses (by sending one
// new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
// the following lines.
expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index dd43a24db..54cee138f 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -21,39 +21,39 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestZeroWindowProbeRetransmit tests retransmits of zero window probes
// to be sent at exponentially inreasing time intervals.
func TestZeroWindowProbeRetransmit(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
defer dut.Close(acceptFd)
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
// Send and receive sample data to the dut.
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %s", err)
}
@@ -63,9 +63,9 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
// of the recorded first zero probe transmission duration.
//
// Advertize zero receive window again.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
- probeSeq := tb.Uint32(uint32(*conn.RemoteSeqNum() - 1))
- ackProbe := tb.Uint32(uint32(*conn.RemoteSeqNum()))
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
startProbeDuration := time.Second
current := startProbeDuration
@@ -79,7 +79,7 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
// Expect zero-window probe with a timeout which is a function of the typical
// first retransmission time. The retransmission times is supposed to
// exponentially increase.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
t.Fatalf("expected a probe with sequence number %v: loop %d", probeSeq, i)
}
if i == 0 {
@@ -92,13 +92,14 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
t.Fatalf("zero probe came sooner interval %d probe %d\n", p, i)
}
// Acknowledge the zero-window probes from the dut.
- conn.Send(tb.TCP{AckNum: ackProbe, Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
current *= 2
}
// Advertize non-zero window.
- conn.Send(tb.TCP{AckNum: ackProbe, Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.
+ TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
index 6c453505d..c9b3b7af2 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -21,41 +21,41 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestZeroWindowProbe tests few cases of zero window probing over the
// same connection.
func TestZeroWindowProbe(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
defer dut.Close(acceptFd)
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
start := time.Now()
// Send and receive sample data to the dut.
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
sendTime := time.Now().Sub(start)
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %s", err)
}
@@ -63,16 +63,16 @@ func TestZeroWindowProbe(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
- probeSeq := tb.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
// Expected ack number of the ACK for the probe.
- ackProbe := tb.Uint32(uint32(*conn.RemoteSeqNum()))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
// Expect there are no zero-window probes sent until there is data to be sent out
// from the dut.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
t.Fatalf("unexpected a packet with sequence number %v: %s", probeSeq, err)
}
@@ -80,7 +80,7 @@ func TestZeroWindowProbe(t *testing.T) {
// Ask the dut to send out data.
dut.Send(acceptFd, sampleData, 0)
// Expect zero-window probe from the dut.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %v: %s", probeSeq, err)
}
// Expect the probe to be sent after some time. Compare against the previous
@@ -94,9 +94,9 @@ func TestZeroWindowProbe(t *testing.T) {
// and sends out the sample payload after the send window opens.
//
// Advertize non-zero window to the dut and ack the zero window probe.
- conn.Send(tb.TCP{AckNum: ackProbe, Flags: tb.Uint8(header.TCPFlagAck)})
+ conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
@@ -104,9 +104,9 @@ func TestZeroWindowProbe(t *testing.T) {
// Check if the dut responds as we do for a similar probe sent to it.
// Basically with sequence number to one byte behind the unacknowledged
// sequence number.
- p := tb.Uint32(uint32(*conn.LocalSeqNum()))
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), SeqNum: tb.Uint32(uint32(*conn.LocalSeqNum() - 1))})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
+ p := testbench.Uint32(uint32(*conn.LocalSeqNum()))
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with ack number: %d: %s", p, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
index 193427fb9..749281d9d 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -21,39 +21,39 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are
// retransmitting zero window probes.
func TestZeroWindowProbeUserTimeout(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close()
- conn.Handshake()
+ conn.Connect()
acceptFd, _ := dut.Accept(listenFd)
defer dut.Close(acceptFd)
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
- samplePayload := &tb.Payload{Bytes: sampleData}
+ samplePayload := &testbench.Payload{Bytes: sampleData}
// Send and receive sample data to the dut.
dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %s", err)
}
@@ -61,15 +61,15 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
- probeSeq := tb.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
start := time.Now()
// Ask the dut to send out data.
dut.Send(acceptFd, sampleData, 0)
// Expect zero-window probe from the dut.
- if _, err := conn.ExpectData(&tb.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %v: %s", probeSeq, err)
}
// Record the duration for first probe, the dut sends the zero window probe after
@@ -82,7 +82,7 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// Reduce the retransmit timeout.
dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
// Advertize zero window again.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Ask the dut to send out data that would trigger zero window probe retransmissions.
dut.Send(acceptFd, sampleData, 0)
@@ -91,8 +91,8 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// Expect the connection to have timed out and closed which would cause the dut
// to reply with a RST to the ACK we send.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index ca4df2ab0..b754918f6 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -26,11 +26,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
type connectionMode bool
@@ -59,12 +59,12 @@ func (e icmpError) String() string {
return "Unknown ICMP error"
}
-func (e icmpError) ToICMPv4() *tb.ICMPv4 {
+func (e icmpError) ToICMPv4() *testbench.ICMPv4 {
switch e {
case portUnreachable:
- return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4DstUnreachable), Code: tb.Uint8(header.ICMPv4PortUnreachable)}
+ return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4PortUnreachable)}
case timeToLiveExceeded:
- return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4TimeExceeded), Code: tb.Uint8(header.ICMPv4TTLExceeded)}
+ return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), Code: testbench.Uint8(header.ICMPv4TTLExceeded)}
}
return nil
}
@@ -76,8 +76,8 @@ type errorDetection struct {
}
type testData struct {
- dut *tb.DUT
- conn *tb.UDPIPv4
+ dut *testbench.DUT
+ conn *testbench.UDPIPv4
remoteFD int32
remotePort uint16
cleanFD int32
@@ -95,25 +95,26 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
}
// sendICMPError sends an ICMP error message in response to a UDP datagram.
-func sendICMPError(conn *tb.UDPIPv4, icmpErr icmpError, udp *tb.UDP) error {
+func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error {
+ layers := (*testbench.Connection)(conn).CreateFrame(nil)
+ layers = layers[:len(layers)-1]
+ ip, ok := udp.Prev().(*testbench.IPv4)
+ if !ok {
+ return fmt.Errorf("expected %s to be IPv4", udp.Prev())
+ }
if icmpErr == timeToLiveExceeded {
- ip, ok := udp.Prev().(*tb.IPv4)
- if !ok {
- return fmt.Errorf("expected %s to be IPv4", udp.Prev())
- }
*ip.TTL = 1
// Let serialization recalculate the checksum since we set the TTL
// to 1.
ip.Checksum = nil
-
- // Note that the ICMP payload is valid in this case because the UDP
- // payload is empty. If the UDP payload were not empty, the packet
- // length during serialization may not be calculated correctly,
- // resulting in a mal-formed packet.
- conn.SendIP(icmpErr.ToICMPv4(), ip, udp)
- } else {
- conn.SendIP(icmpErr.ToICMPv4(), udp.Prev(), udp)
}
+ // Note that the ICMP payload is valid in this case because the UDP
+ // payload is empty. If the UDP payload were not empty, the packet
+ // length during serialization may not be calculated correctly,
+ // resulting in a mal-formed packet.
+ layers = append(layers, icmpErr.ToICMPv4(), ip, udp)
+
+ (*testbench.Connection)(conn).SendFrameStateless(layers)
return nil
}
@@ -123,10 +124,10 @@ func sendICMPError(conn *tb.UDPIPv4, icmpErr icmpError, udp *tb.UDP) error {
// first recv should succeed immediately.
func testRecv(ctx context.Context, d testData) error {
// Check that receiving on the clean socket works.
- d.conn.Send(tb.UDP{DstPort: &d.cleanPort})
+ d.conn.Send(testbench.UDP{DstPort: &d.cleanPort})
d.dut.Recv(d.cleanFD, 100, 0)
- d.conn.Send(tb.UDP{})
+ d.conn.Send(testbench.UDP{})
if d.wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
@@ -151,7 +152,7 @@ func testRecv(ctx context.Context, d testData) error {
func testSendTo(ctx context.Context, d testData) error {
// Check that sending on the clean socket works.
d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(tb.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
+ if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err)
}
@@ -169,7 +170,7 @@ func testSendTo(ctx context.Context, d testData) error {
}
d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(tb.UDP{}, time.Second); err != nil {
+ if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
return fmt.Errorf("did not receive UDP packet as expected: %s", err)
}
return nil
@@ -187,7 +188,7 @@ func testSockOpt(_ context.Context, d testData) error {
// Check that after clearing socket error, sending doesn't fail.
d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(tb.UDP{}, time.Second); err != nil {
+ if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
return fmt.Errorf("did not receive UDP packet as expected: %s", err)
}
return nil
@@ -223,7 +224,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
errorDetection{"SockOpt", false, testSockOpt},
} {
t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
@@ -234,7 +235,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(cleanFD)
- conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
if connect {
@@ -243,7 +244,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
}
dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
- udp, err := conn.Expect(tb.UDP{}, time.Second)
+ udp, err := conn.Expect(testbench.UDP{}, time.Second)
if err != nil {
t.Fatalf("did not receive message from DUT: %s", err)
}
@@ -258,7 +259,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
// involved in the generation of the ICMP error. As such,
// interactions between it and the the DUT should be independent of
// the ICMP error at least at the port level.
- connClean := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer connClean.Close()
errDetectConn = &connClean
@@ -281,7 +282,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
wantErrno := wantErrno(connect, icmpErr)
t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
@@ -292,7 +293,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(cleanFD)
- conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
if connect {
@@ -301,7 +302,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
}
dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
- udp, err := conn.Expect(tb.UDP{}, time.Second)
+ udp, err := conn.Expect(testbench.UDP{}, time.Second)
if err != nil {
t.Fatalf("did not receive message from DUT: %s", err)
}
@@ -355,8 +356,8 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
t.Fatal(err)
}
- conn.Send(tb.UDP{DstPort: &cleanPort})
- conn.Send(tb.UDP{})
+ conn.Send(testbench.UDP{DstPort: &cleanPort})
+ conn.Send(testbench.UDP{})
wg.Wait()
})
}
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
index 0bae18ba3..77a9bfa1d 100644
--- a/test/packetimpact/tests/udp_recv_multicast_test.go
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -21,22 +21,20 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func TestUDPRecvMulticast(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(boundFD)
- conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
- frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")})
- frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))
- conn.SendFrame(frame)
+ conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{})
dut.Recv(boundFD, 100, 0)
}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index 350875a6f..a7db384ad 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -22,11 +22,11 @@ import (
"time"
"golang.org/x/sys/unix"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
func generateRandomPayload(t *testing.T, n int) string {
@@ -39,11 +39,11 @@ func generateRandomPayload(t *testing.T, n int) string {
}
func TestUDPRecv(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(boundFD)
- conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
testCases := []struct {
@@ -59,8 +59,7 @@ func TestUDPRecv(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte(tc.payload)})
- conn.SendFrame(frame)
+ conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: []byte(tc.payload)})
if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), tc.payload; got != want {
t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want)
}
@@ -69,11 +68,11 @@ func TestUDPRecv(t *testing.T) {
}
func TestUDPSend(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(boundFD)
- conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
testCases := []struct {
@@ -93,7 +92,7 @@ func TestUDPSend(t *testing.T) {
if got, want := int(dut.SendTo(boundFD, []byte(tc.payload), 0, conn.LocalAddr())), len(tc.payload); got != want {
t.Fatalf("short write got: %d, want: %d", got, want)
}
- if _, err := conn.ExpectData(tb.UDP{SrcPort: &remotePort}, tb.Payload{Bytes: []byte(tc.payload)}, 1*time.Second); err != nil {
+ if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: []byte(tc.payload)}, 1*time.Second); err != nil {
t.Fatal(err)
}
})
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 402ba4064..5a83f8060 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -93,6 +93,7 @@ def _syscall_test(
# we figure out how to request ipv4 sockets on Guitar machines.
if network == "host":
tags.append("noguitar")
+ tags.append("block-network")
# Disable off-host networking.
tags.append("requires-net:loopback")
diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD
index da1e331e1..f76e2ddc0 100644
--- a/test/runtimes/proctor/BUILD
+++ b/test/runtimes/proctor/BUILD
@@ -21,7 +21,7 @@ go_test(
size = "small",
srcs = ["proctor_test.go"],
library = ":proctor",
- nocgo = 1,
+ pure = True,
deps = [
"//pkg/test/testutil",
],
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 3406a2de8..d68afbe44 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -401,6 +401,14 @@ syscall_test(
)
syscall_test(
+ size = "medium",
+ # Takes too long under gotsan to run.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:ping_socket_test",
+ vfs2 = "True",
+)
+
+syscall_test(
size = "large",
add_overlay = True,
shard_count = 5,
@@ -700,6 +708,14 @@ syscall_test(
syscall_test(
size = "large",
shard_count = 50,
+ # Takes too long for TSAN. Creates a lot of TCP sockets.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test",
vfs2 = "True",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index f4b5de18d..6d051f48b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -802,10 +802,13 @@ cc_binary(
],
linkstatic = 1,
deps = [
+ ":socket_test_util",
"//test/util:file_descriptor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
gtest,
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
@@ -849,10 +852,7 @@ cc_binary(
cc_binary(
name = "fpsig_nested_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["fpsig_nested.cc"],
- arm64 = [],
- ),
+ srcs = ["fpsig_nested.cc"],
linkstatic = 1,
deps = [
gtest,
@@ -1412,6 +1412,21 @@ cc_binary(
)
cc_binary(
+ name = "ping_socket_test",
+ testonly = 1,
+ srcs = ["ping_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "pipe_test",
testonly = 1,
srcs = ["pipe.cc"],
@@ -2781,6 +2796,26 @@ cc_binary(
)
cc_binary(
+ name = "socket_inet_loopback_nogotsan_test",
+ testonly = 1,
+ srcs = ["socket_inet_loopback_nogotsan.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
name = "socket_netlink_test",
testonly = 1,
srcs = ["socket_netlink.cc"],
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index e08c578f0..f65a14fb8 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <stdio.h>
+#include <sys/socket.h>
#include <sys/un.h>
#include <algorithm>
@@ -141,6 +142,47 @@ TEST_P(AllSocketPairTest, Connect) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ConnectWithWrongType) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+ switch (type) {
+ case SOCK_STREAM:
+ type = SOCK_SEQPACKET;
+ break;
+ case SOCK_SEQPACKET:
+ type = SOCK_STREAM;
+ break;
+ }
+
+ const FileDescriptor another_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ if (sockets->first_addr()->sa_data[0] != 0) {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EPROTOTYPE));
+ } else {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ }
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
TEST_P(AllSocketPairTest, ConnectNonListening) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 12c9b05ca..e09afafe9 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -673,6 +673,33 @@ TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) {
EXPECT_EQ(execve_errno, ELOOP);
}
+TEST(ExecveatTest, UnshareFiles) {
+ TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755));
+ const FileDescriptor fd_closed_on_exec =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
+
+ pid_t child;
+ EXPECT_THAT(child = syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES,
+ 0, 0, 0, 0),
+ SyscallSucceeds());
+ if (child == 0) {
+ ExecveArray argv = {"test"};
+ ExecveArray envp;
+ ASSERT_THAT(
+ execve(RunfilePath(kBasicWorkload).c_str(), argv.get(), envp.get()),
+ SyscallSucceeds());
+ _exit(1);
+ }
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(status, 0);
+
+ struct stat st;
+ EXPECT_THAT(fstat(fd_closed_on_exec.get(), &st), SyscallSucceeds());
+}
+
TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) {
std::string parent_dir = "/tmp";
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 1a9f203b9..18d2f22c1 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -438,7 +438,12 @@ TEST(ElfTest, MissingText) {
ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
SyscallSucceedsWithValue(child));
// It runs off the end of the zeroes filling the end of the page.
+#if defined(__x86_64__)
EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status;
+#elif defined(__aarch64__)
+ // 0 is an invalid instruction opcode on arm64.
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGILL) << status;
+#endif
}
// Typical ELF with a data + bss segment
diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc
index 3ecb8db8e..638a93979 100644
--- a/test/syscalls/linux/flock.cc
+++ b/test/syscalls/linux/flock.cc
@@ -21,6 +21,7 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -34,11 +35,6 @@ namespace {
class FlockTest : public FileTest {};
-TEST_F(FlockTest, BadFD) {
- // EBADF: fd is not an open file descriptor.
- ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF));
-}
-
TEST_F(FlockTest, InvalidOpCombinations) {
// The operation cannot be both exclusive and shared.
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB),
@@ -57,15 +53,6 @@ TEST_F(FlockTest, NoOperationSpecified) {
SyscallFailsWithErrno(EINVAL));
}
-TEST(FlockTestNoFixture, FlockSupportsPipes) {
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
-
- EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds());
- EXPECT_THAT(close(fds[0]), SyscallSucceeds());
- EXPECT_THAT(close(fds[1]), SyscallSucceeds());
-}
-
TEST_F(FlockTest, TestSimpleExLock) {
// Test that we can obtain an exclusive lock (no other holders)
// and that we can unlock it.
@@ -583,6 +570,66 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) {
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
}
+TEST(FlockTestNoFixture, BadFD) {
+ // EBADF: fd is not an open file descriptor.
+ ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSymlink) {
+ // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH
+ // is supported.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockProc) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockPipe) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds());
+ // Check that the pipe was locked above.
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EAGAIN));
+
+ EXPECT_THAT(flock(fds[0], LOCK_UN), SyscallSucceeds());
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallSucceeds());
+
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSocket) {
+ int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX));
+ ASSERT_THAT(
+ bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(flock(sock, LOCK_EX | LOCK_NB), SyscallSucceeds());
+ EXPECT_THAT(close(sock), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc
index d08111cd3..c47567b4e 100644
--- a/test/syscalls/linux/fpsig_fork.cc
+++ b/test/syscalls/linux/fpsig_fork.cc
@@ -37,11 +37,11 @@ namespace {
#define __stringify_1(x...) #x
#define __stringify(x...) __stringify_1(x)
#define GET_FPREG(var, regname) \
- asm volatile("str "__stringify(regname) ", %0" : "=m"(var))
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
#define SET_FPREG(var, regname) \
- asm volatile("ldr "__stringify(regname) ", %0" : "=m"(var))
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
#define GET_FP0(var) GET_FPREG(var, d0)
-#define SET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
#endif
int parent, child;
diff --git a/test/syscalls/linux/fpsig_nested.cc b/test/syscalls/linux/fpsig_nested.cc
index c476a8e7a..302d928d1 100644
--- a/test/syscalls/linux/fpsig_nested.cc
+++ b/test/syscalls/linux/fpsig_nested.cc
@@ -26,9 +26,22 @@ namespace testing {
namespace {
+#ifdef __x86_64__
#define GET_XMM(__var, __xmm) \
asm volatile("movq %%" #__xmm ", %0" : "=r"(__var))
#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var))
+#define GET_FP0(__var) GET_XMM(__var, xmm0)
+#define SET_FP0(__var) SET_XMM(__var, xmm0)
+#elif __aarch64__
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+#define GET_FPREG(var, regname) \
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
+#define SET_FPREG(var, regname) \
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
+#define GET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
+#endif
int pid;
int tid;
@@ -40,20 +53,21 @@ void sigusr2(int s, siginfo_t* siginfo, void* _uc) {
uint64_t val = SIGUSR2;
// Record the value of %xmm0 on entry and then clobber it.
- GET_XMM(entryxmm[1], xmm0);
- SET_XMM(val, xmm0);
- GET_XMM(exitxmm[1], xmm0);
+ GET_FP0(entryxmm[1]);
+ SET_FP0(val);
+ GET_FP0(exitxmm[1]);
}
void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
uint64_t val = SIGUSR1;
// Record the value of %xmm0 on entry and then clobber it.
- GET_XMM(entryxmm[0], xmm0);
- SET_XMM(val, xmm0);
+ GET_FP0(entryxmm[0]);
+ SET_FP0(val);
// Send a SIGUSR2 to ourself. The signal mask is configured such that
// the SIGUSR2 handler will run before this handler returns.
+#ifdef __x86_64__
asm volatile(
"movl %[killnr], %%eax;"
"movl %[pid], %%edi;"
@@ -66,10 +80,19 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
: "rax", "rdi", "rsi", "rdx",
// Clobbered by syscall.
"rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR2));
+#endif
// Record value of %xmm0 again to verify that the nested signal handler
// does not clobber it.
- GET_XMM(exitxmm[0], xmm0);
+ GET_FP0(exitxmm[0]);
}
TEST(FPSigTest, NestedSignals) {
@@ -98,8 +121,9 @@ TEST(FPSigTest, NestedSignals) {
// to signal the current thread ensures that this is the clobbered thread.
uint64_t expected = 0xdeadbeeffacefeed;
- SET_XMM(expected, xmm0);
+ SET_FP0(expected);
+#ifdef __x86_64__
asm volatile(
"movl %[killnr], %%eax;"
"movl %[pid], %%edi;"
@@ -112,9 +136,18 @@ TEST(FPSigTest, NestedSignals) {
: "rax", "rdi", "rsi", "rdx",
// Clobbered by syscall.
"rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR1));
+#endif
uint64_t got;
- GET_XMM(got, xmm0);
+ GET_FP0(got);
//
// The checks below verifies the following:
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index 640fe6bfc..670c0284b 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -416,6 +416,29 @@ TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) {
EXPECT_EQ(stat.st_size, 0);
}
+TEST_F(OpenTest, CanTruncateWithStrangePermissions) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ const DisableSave ds; // Permissions are dropped.
+ std::string path = NewTempAbsPath();
+ int fd;
+ // Create a file without user permissions.
+ EXPECT_THAT( // SAVE_BELOW
+ fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ // Cannot open file because we are owner and have no permissions set.
+ EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // We *can* chmod the file, because we are the owner.
+ EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds());
+
+ // Now we can open the file again.
+ EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
new file mode 100644
index 000000000..a9bfdb37b
--- /dev/null
+++ b/test/syscalls/linux/ping_socket.cc
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class PingSocket : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void PingSocket::SetUp() {
+ // On some hosts ping sockets are restricted to specific groups using the
+ // sysctl "ping_group_range".
+ int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
+ if (s < 0 && errno == EPERM) {
+ GTEST_SKIP();
+ }
+ close(s);
+
+ addr_ = {};
+ // Just a random port as the destination port number is irrelevant for ping
+ // sockets.
+ addr_.sin_port = 12345;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void PingSocket::TearDown() {}
+
+// Test ICMP port exhaustion returns EAGAIN.
+//
+// We disable both random/cooperative S/R for this test as it makes way too many
+// syscalls.
+TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+ DisableSave ds;
+ std::vector<FileDescriptor> sockets;
+ constexpr int kSockets = 65536;
+ addr_.sin_port = 0;
+ for (int i = 0; i < kSockets; i++) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+ int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_));
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 1e35a4a8b..7a316427d 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -259,9 +259,9 @@ TEST_F(PollTest, Nfds) {
TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0);
// gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE.
- constexpr rlim_t gVisorMax = 1048576;
- if (rlim.rlim_cur > gVisorMax) {
- rlim.rlim_cur = gVisorMax;
+ constexpr rlim_t maxFD = 4096;
+ if (rlim.rlim_cur > maxFD) {
+ rlim.rlim_cur = maxFD;
TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0);
}
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
index 1967329ee..4b2cb0606 100644
--- a/test/syscalls/linux/socket_bind_to_device_sequence.cc
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -363,8 +363,9 @@ TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
}
TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
+ // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets.
+ SKIP_IF(IsRunningOnGvisor() &&
+ (GetParam().type & SOCK_STREAM) == SOCK_STREAM);
ASSERT_NO_FATAL_FAILURE(
BindSocket(/* reusePort */ false, /* reuse_addr */ true));
@@ -405,8 +406,9 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) {
}
TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
+ // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets.
+ SKIP_IF(IsRunningOnGvisor() &&
+ (GetParam().type & SOCK_STREAM) == SOCK_STREAM);
ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
/* reuse_addr */ true,
@@ -434,8 +436,9 @@ TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) {
}
TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
+ // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets.
+ SKIP_IF(IsRunningOnGvisor() &&
+ (GetParam().type & SOCK_STREAM) == SOCK_STREAM);
ASSERT_NO_FATAL_FAILURE(BindSocket(
/* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0));
@@ -469,10 +472,20 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) {
/* bind_to_device */ 0, EADDRINUSE));
}
-// This behavior seems like a bug?
TEST_P(BindToDeviceSequenceTest,
BindReuseAddrThenReuseAddrReusePortThenReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
+ // The behavior described in this test seems like a Linux bug. It doesn't
+ // make any sense and it is unlikely that any applications rely on it.
+ //
+ // Both SO_REUSEADDR and SO_REUSEPORT allow binding multiple UDP sockets to
+ // the same address and deliver each packet to exactly one of the bound
+ // sockets. If both are enabled, one of the strategies is selected to route
+ // packets. The strategy is selected dynamically based on the settings of the
+ // currently bound sockets. Usually, the strategy is selected based on the
+ // common setting (SO_REUSEADDR or SO_REUSEPORT) amongst the sockets, but for
+ // some reason, Linux allows binding sets of sockets with no overlapping
+ // settings in some situations. In this case, it is not obvious which strategy
+ // would be selected as the configured setting is a contradiction.
SKIP_IF(IsRunningOnGvisor());
ASSERT_NO_FATAL_FAILURE(BindSocket(
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index fa890ec98..a914d9a12 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1900,9 +1900,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
auto const& param = GetParam();
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM);
-
// Bind the v6 loopback on a dual stack socket.
TestAddress const& test_addr = V6Loopback();
sockaddr_storage bound_addr = test_addr.addr;
@@ -2091,9 +2088,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
V4MappedEphemeralPortReservedResueAddr) {
auto const& param = GetParam();
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM);
-
// Bind the v4 loopback on a dual stack socket.
TestAddress const& test_addr = V4MappedLoopback();
sockaddr_storage bound_addr = test_addr.addr;
@@ -2283,9 +2277,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
auto const& param = GetParam();
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM);
-
// Bind the v4 loopback on a v4 socket.
TestAddress const& test_addr = V4Loopback();
sockaddr_storage bound_addr = test_addr.addr;
diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
new file mode 100644
index 000000000..2324c7f6a
--- /dev/null
+++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <string.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Gt;
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+struct TestParam {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) {
+ return absl::StrCat("Listen", info.param.listener.description, "_Connect",
+ info.param.connector.description);
+}
+
+using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>;
+
+// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral
+// ports are already in use for a given destination ip/port.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 10;
+ constexpr int kClients = 65536;
+
+ // Create the listening socket.
+ auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+
+ // Now we keep opening connections till we run out of local ephemeral ports.
+ // and assert the error we get back.
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ std::vector<FileDescriptor> servers;
+
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret == 0) {
+ clients.push_back(std::move(client));
+ FileDescriptor server =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ servers.push_back(std::move(server));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL));
+ break;
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetLoopbackTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Any(), V4MappedAny()},
+ TestParam{V4Any(), V4MappedLoopback()},
+ TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+ TestParam{V4MappedAny(), V4Any()},
+ TestParam{V4MappedAny(), V4Loopback()},
+ TestParam{V4MappedAny(), V4MappedAny()},
+ TestParam{V4MappedAny(), V4MappedLoopback()},
+ TestParam{V4MappedLoopback(), V4Any()},
+ TestParam{V4MappedLoopback(), V4Loopback()},
+ TestParam{V4MappedLoopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()},
+ TestParam{V6Any(), V4MappedAny()},
+ TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()},
+ TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Any()},
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index fa81845fd..15adc8d0e 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -524,6 +524,7 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) {
// Copied from include/net/tcp.h.
constexpr int MAX_TCP_KEEPIDLE = 32767;
constexpr int MAX_TCP_KEEPINTVL = 32767;
+constexpr int MAX_TCP_KEEPCNT = 127;
TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -575,6 +576,78 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) {
EXPECT_EQ(get, MAX_TCP_KEEPINTVL);
}
+TEST_P(TCPSocketPairTest, TCPKeepcountDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 9); // 9 keepalive probes.
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero,
+ sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPCNT);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, keepaliveCount);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = -5;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST_P(TCPSocketPairTest, SetOOBInline) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 1c533fdf2..edb86aded 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -225,9 +225,6 @@ TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
TEST_P(UDPSocketPairTest, SetReuseAddr) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
@@ -292,9 +289,6 @@ TEST_P(UDPSocketPairTest, SetReusePort) {
TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
index ca597e267..5536ab53a 100644
--- a/test/syscalls/linux/socket_ip_unbound.cc
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -157,7 +157,7 @@ TEST_P(IPUnboundSocketTest, TOSDefault) {
int get = -1;
socklen_t get_sz = sizeof(get);
constexpr int kDefaultTOS = 0;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, kDefaultTOS);
@@ -173,7 +173,7 @@ TEST_P(IPUnboundSocketTest, SetTOS) {
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, set);
@@ -188,7 +188,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOS) {
SyscallSucceedsWithValue(0));
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, set);
@@ -210,7 +210,7 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, kDefaultTOS);
@@ -229,7 +229,7 @@ TEST_P(IPUnboundSocketTest, CheckSkipECN) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
@@ -249,7 +249,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) {
}
int get = -1;
socklen_t get_sz = 0;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, 0);
EXPECT_EQ(get, -1);
@@ -276,7 +276,7 @@ TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) {
}
uint get = -1;
socklen_t get_sz = i;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, expect_sz);
// Account for partial copies by getsockopt, retrieve the lower
@@ -297,7 +297,7 @@ TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) {
// We expect the system call handler to only copy atmost sizeof(int) bytes
// as asserted by the check below. Hence, we do not expect the copy to
// overflow in getsockopt.
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(int));
EXPECT_EQ(get, set);
@@ -325,7 +325,7 @@ TEST_P(IPUnboundSocketTest, NegativeTOS) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
@@ -338,20 +338,20 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) {
TOSOption t = GetTOSOption(GetParam().domain);
int expect;
if (GetParam().domain == AF_INET) {
- EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
SyscallSucceedsWithValue(0));
expect = static_cast<uint8_t>(set);
if (GetParam().protocol == IPPROTO_TCP) {
expect &= ~INET_ECN_MASK;
}
} else {
- EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
SyscallFailsWithErrno(EINVAL));
expect = 0;
}
int get = 0;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
@@ -377,8 +377,11 @@ TEST_P(IPUnboundSocketTest, NullTOS) {
//
// Linux's implementation would need fixing as passing a nullptr as optval
// and non-zero optlen may not be valid.
- EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
- SyscallSucceedsWithValue(0));
+ // TODO(b/158666797): Combine the gVisor and linux cases for IPv6.
+ // Some kernel versions return EFAULT, so we handle both.
+ EXPECT_THAT(
+ setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(0)));
}
}
socklen_t get_sz = sizeof(int);
@@ -419,6 +422,38 @@ TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) {
EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL));
}
+TEST_P(IPUnboundSocketTest, ReuseAddrDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // FIXME(b/76031995): gVisor TCP sockets default to SO_REUSEADDR enabled.
+ SKIP_IF(IsRunningOnGvisor() &&
+ (GetParam().type & SOCK_STREAM) == SOCK_STREAM);
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+TEST_P(IPUnboundSocketTest, SetReuseAddr) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ASSERT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
INSTANTIATE_TEST_SUITE_P(
IPUnboundSockets, IPUnboundSocketTest,
::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index bc4b07a62..4db9553e1 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -1672,10 +1672,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
}
// Check that SO_REUSEADDR always delivers to the most recently bound socket.
-TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
+//
+// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable
+// random and co-op save (below) once that is fixed.
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
std::vector<std::unique_ptr<FileDescriptor>> sockets;
sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
@@ -1696,6 +1696,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) {
constexpr int kMessageSize = 200;
+ // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly.
+ const DisableSave ds;
+
for (int i = 0; i < 10; i++) {
// Add a new receiver.
sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
@@ -1836,9 +1839,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) {
}
TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1880,9 +1880,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
}
TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1927,9 +1924,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
}
TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -2020,9 +2014,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) {
}
TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
- // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets.
- SKIP_IF(IsRunningOnGvisor());
-
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -2129,6 +2120,39 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(kMessageSize));
}
+// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ constexpr int kClients = 65536;
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+ std::vector<std::unique_ptr<FileDescriptor>> sockets;
+ for (int i = 0; i < kClients; i++) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len);
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
// Test that socket will receive packet info control message.
TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
// TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc
index 84d3a569e..6d03df4d9 100644
--- a/test/syscalls/linux/socket_unix_seqpacket.cc
+++ b/test/syscalls/linux/socket_unix_seqpacket.cc
@@ -43,6 +43,24 @@ TEST_P(SeqpacketUnixSocketPairTest, ReadOneSideClosed) {
SyscallSucceedsWithValue(0));
}
+TEST_P(SeqpacketUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallSucceedsWithValue(3));
+
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->first_fd(), data, sizeof(data)),
+ SyscallSucceedsWithValue(3));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 563467365..99e77b89e 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -89,6 +89,20 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) {
SyscallFailsWithErrno(ECONNRESET));
}
+TEST_P(StreamUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallFailsWithErrno(EISCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index 6195b11e1..97d554e72 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -398,5 +398,25 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
}
}
+// Write hang bug found by syskaller: b/155928773
+// https://syzkaller.appspot.com/bug?id=065b893bd8d1d04a4e0a1d53c578537cde1efe99
+TEST_F(TuntapTest, WriteHangBug155928773) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.1", &remote.sin_addr);
+ // Return values do not matter in this test.
+ connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote));
+ write(sock, "hello", 5);
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 41eded16d..0bd468786 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -96,6 +96,7 @@ def go_imports(name, src, out):
cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
)
+# TODO(b/158696872): Enable nogo by default.
def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, nogo = False, **kwargs):
"""Wraps the standard go_library and does stateification and marshalling.
diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go
index 96c965ea2..eec93534b 100644
--- a/tools/go_generics/globals/scope.go
+++ b/tools/go_generics/globals/scope.go
@@ -72,6 +72,10 @@ func (s *scope) deepLookup(n string) *symbol {
}
func (s *scope) add(name string, kind SymKind, pos token.Pos) {
+ if s.syms[name] != nil {
+ return
+ }
+
s.syms[name] = &symbol{
kind: kind,
pos: pos,