summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml10
-rw-r--r--g3doc/user_guide/FAQ.md7
-rw-r--r--g3doc/user_guide/containerd/configuration.md25
-rw-r--r--g3doc/user_guide/containerd/quick_start.md5
-rw-r--r--images/basic/fsstress/Dockerfile.x86_64 (renamed from images/basic/fsstress/Dockerfile)2
-rw-r--r--pkg/abi/linux/fcntl.go3
-rw-r--r--pkg/abi/linux/socket.go12
-rw-r--r--pkg/abi/linux/tcp.go9
-rw-r--r--pkg/coverage/coverage.go18
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_state.go4
-rw-r--r--pkg/sentry/fs/fs.go2
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go4
-rw-r--r--pkg/sentry/fs/host/socket.go2
-rw-r--r--pkg/sentry/fs/lock/BUILD2
-rw-r--r--pkg/sentry/fs/lock/lock.go153
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go5
-rw-r--r--pkg/sentry/fs/lock/lock_test.go124
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go11
-rw-r--r--pkg/sentry/fsimpl/devpts/replica.go11
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go11
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go11
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go14
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go12
-rw-r--r--pkg/sentry/fsimpl/host/host.go11
-rw-r--r--pkg/sentry/fsimpl/host/socket.go2
-rw-r--r--pkg/sentry/fsimpl/host/tty.go11
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go11
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go11
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go11
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go21
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go6
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go18
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go11
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go24
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go4
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/fd_table.go12
-rw-r--r--pkg/sentry/kernel/kernel.go4
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go11
-rw-r--r--pkg/sentry/kernel/task_block.go2
-rw-r--r--pkg/sentry/kernel/threads.go12
-rw-r--r--pkg/sentry/mm/procfs.go7
-rw-r--r--pkg/sentry/mm/syscalls.go27
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go12
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go3
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go36
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go4
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go11
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go30
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go31
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go9
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go11
-rw-r--r--pkg/sentry/socket/netstack/BUILD2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go265
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go25
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go8
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go2
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go87
-rw-r--r--pkg/sentry/socket/unix/unix.go3
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go11
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go78
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/lock.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go39
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go13
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go12
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go4
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/file_description.go84
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go30
-rw-r--r--pkg/sentry/vfs/lock.go47
-rw-r--r--pkg/sentry/vfs/opath.go139
-rw-r--r--pkg/sentry/vfs/options.go2
-rw-r--r--pkg/sentry/vfs/vfs.go12
-rw-r--r--pkg/sync/BUILD1
-rw-r--r--pkg/syserr/netstack.go189
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go85
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go18
-rw-r--r--pkg/tcpip/checker/checker.go4
-rw-r--r--pkg/tcpip/errors.go538
-rw-r--r--pkg/tcpip/faketime/BUILD5
-rw-r--r--pkg/tcpip/faketime/faketime.go318
-rw-r--r--pkg/tcpip/header/ipv4.go58
-rw-r--r--pkg/tcpip/header/ipv6.go4
-rw-r--r--pkg/tcpip/header/ipv6_test.go8
-rw-r--r--pkg/tcpip/link/channel/channel.go4
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go4
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go16
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go125
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go4
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go260
-rw-r--r--pkg/tcpip/link/loopback/loopback.go4
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/nested/nested.go4
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go34
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go17
-rw-r--r--pkg/tcpip/link/rawfile/BUILD1
-rw-r--r--pkg/tcpip/link/rawfile/errors.go87
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go15
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go19
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go4
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go4
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD18
-rw-r--r--pkg/tcpip/network/arp/arp.go146
-rw-r--r--pkg/tcpip/network/arp/arp_test.go210
-rw-r--r--pkg/tcpip/network/arp/stats.go70
-rw-r--r--pkg/tcpip/network/arp/stats_test.go93
-rw-r--r--pkg/tcpip/network/ip/BUILD5
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go4
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go4
-rw-r--r--pkg/tcpip/network/ip/stats.go100
-rw-r--r--pkg/tcpip/network/ip_test.go290
-rw-r--r--pkg/tcpip/network/ipv4/BUILD13
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go103
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go60
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go328
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go134
-rw-r--r--pkg/tcpip/network/ipv4/stats.go190
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go99
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go205
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go135
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go217
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go118
-rw-r--r--pkg/tcpip/network/ipv6/mld.go24
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go27
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go239
-rw-r--r--pkg/tcpip/network/ipv6/stats.go132
-rw-r--r--pkg/tcpip/network/testutil/BUILD2
-rw-r--r--pkg/tcpip/network/testutil/testutil.go76
-rw-r--r--pkg/tcpip/network/testutil/testutil_unsafe.go26
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go16
-rw-r--r--pkg/tcpip/ports/ports_test.go145
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go36
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go11
-rw-r--r--pkg/tcpip/socketops.go108
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go34
-rw-r--r--pkg/tcpip/stack/conntrack.go12
-rw-r--r--pkg/tcpip/stack/forwarding_test.go78
-rw-r--r--pkg/tcpip/stack/iptables.go30
-rw-r--r--pkg/tcpip/stack/iptables_types.go64
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go145
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go111
-rw-r--r--pkg/tcpip/stack/ndp_test.go186
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go13
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go313
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go13
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go342
-rw-r--r--pkg/tcpip/stack/nic.go173
-rw-r--r--pkg/tcpip/stack/nic_test.go37
-rw-r--r--pkg/tcpip/stack/nud.go4
-rw-r--r--pkg/tcpip/stack/nud_test.go15
-rw-r--r--pkg/tcpip/stack/pending_packets.go241
-rw-r--r--pkg/tcpip/stack/registration.go80
-rw-r--r--pkg/tcpip/stack/route.go124
-rw-r--r--pkg/tcpip/stack/stack.go264
-rw-r--r--pkg/tcpip/stack/stack_options.go32
-rw-r--r--pkg/tcpip/stack/stack_test.go419
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go28
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go8
-rw-r--r--pkg/tcpip/stack/transport_test.go89
-rw-r--r--pkg/tcpip/tcpip.go436
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go321
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go336
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go778
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go35
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go17
-rw-r--r--pkg/tcpip/tests/integration/route_test.go39
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go143
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go18
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go121
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go22
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go165
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go7
-rw-r--r--pkg/tcpip/transport/raw/protocol.go4
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go18
-rw-r--r--pkg/tcpip/transport/tcp/connect.go127
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go22
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go426
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go54
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go26
-rw-r--r--pkg/tcpip/transport/tcp/rack.go154
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go76
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go38
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go84
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go442
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go13
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go201
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go25
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go119
-rw-r--r--pkg/usermem/usermem.go20
-rw-r--r--runsc/container/container_test.go59
-rw-r--r--runsc/container/multi_container_test.go24
-rw-r--r--runsc/fsgofer/BUILD1
-rw-r--r--runsc/fsgofer/fsgofer.go80
-rw-r--r--runsc/fsgofer/fsgofer_amd64_unsafe.go3
-rw-r--r--runsc/fsgofer/fsgofer_arm64_unsafe.go3
-rw-r--r--runsc/fsgofer/fsgofer_test.go94
-rw-r--r--runsc/mitigate/BUILD20
-rw-r--r--runsc/mitigate/cpu.go235
-rw-r--r--runsc/mitigate/cpu_test.go368
-rw-r--r--runsc/mitigate/mitigate.go20
-rw-r--r--test/fsstress/fsstress_test.go2
-rw-r--r--test/iptables/BUILD2
-rw-r--r--test/iptables/filter_input.go198
-rw-r--r--test/iptables/iptables_test.go32
-rw-r--r--test/iptables/iptables_util.go2
-rw-r--r--test/iptables/nat.go265
-rw-r--r--test/packetimpact/dut/BUILD2
-rw-r--r--test/packetimpact/dut/posix_server.cc40
-rw-r--r--test/packetimpact/proto/posix_server.proto23
-rw-r--r--test/packetimpact/runner/defs.bzl10
-rw-r--r--test/packetimpact/testbench/connections.go40
-rw-r--r--test/packetimpact/testbench/dut.go52
-rw-r--r--test/packetimpact/tests/BUILD41
-rw-r--r--test/packetimpact/tests/ipv4_fragment_reassembly_test.go12
-rw-r--r--test/packetimpact/tests/tcp_info_test.go103
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go15
-rw-r--r--test/packetimpact/tests/tcp_rack_test.go221
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go28
-rw-r--r--test/packetimpact/tests/udp_recv_mcast_bcast_test.go115
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go338
-rw-r--r--test/syscalls/BUILD2
-rw-r--r--test/syscalls/linux/BUILD10
-rw-r--r--test/syscalls/linux/chmod.cc36
-rw-r--r--test/syscalls/linux/chown.cc38
-rw-r--r--test/syscalls/linux/dup.cc85
-rw-r--r--test/syscalls/linux/fadvise64.cc11
-rw-r--r--test/syscalls/linux/fallocate.cc7
-rw-r--r--test/syscalls/linux/fchdir.cc12
-rw-r--r--test/syscalls/linux/fcntl.cc292
-rw-r--r--test/syscalls/linux/getdents.cc26
-rw-r--r--test/syscalls/linux/inotify.cc94
-rw-r--r--test/syscalls/linux/ioctl.cc13
-rw-r--r--test/syscalls/linux/link.cc55
-rw-r--r--test/syscalls/linux/madvise.cc12
-rw-r--r--test/syscalls/linux/mmap.cc12
-rw-r--r--test/syscalls/linux/open.cc22
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc8
-rw-r--r--test/syscalls/linux/ping_socket.cc55
-rw-r--r--test/syscalls/linux/pread64.cc10
-rw-r--r--test/syscalls/linux/preadv.cc14
-rw-r--r--test/syscalls/linux/preadv2.cc18
-rw-r--r--test/syscalls/linux/pty.cc11
-rw-r--r--test/syscalls/linux/pwrite64.cc11
-rw-r--r--test/syscalls/linux/pwritev2.cc17
-rw-r--r--test/syscalls/linux/raw_socket.cc8
-rw-r--r--test/syscalls/linux/read.cc9
-rw-r--r--test/syscalls/linux/readv.cc14
-rw-r--r--test/syscalls/linux/shm.cc36
-rw-r--r--test/syscalls/linux/sigreturn_amd64.cc (renamed from test/syscalls/linux/sigiret.cc)0
-rw-r--r--test/syscalls/linux/sigreturn_arm64.cc97
-rw-r--r--test/syscalls/linux/socket.cc75
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc27
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc6
-rw-r--r--test/syscalls/linux/socket_unix_cmsg.cc2
-rw-r--r--test/syscalls/linux/stat.cc37
-rw-r--r--test/syscalls/linux/statfs.cc10
-rw-r--r--test/syscalls/linux/symlink.cc30
-rw-r--r--test/syscalls/linux/sync.cc12
-rw-r--r--test/syscalls/linux/tcp_socket.cc23
-rw-r--r--test/syscalls/linux/truncate.cc10
-rw-r--r--test/syscalls/linux/write.cc38
-rw-r--r--test/syscalls/linux/xattr.cc21
-rw-r--r--test/util/BUILD1
-rw-r--r--test/util/fs_util.cc16
-rw-r--r--test/util/fs_util.h2
-rw-r--r--test/util/logging.cc8
-rw-r--r--test/util/logging.h39
-rw-r--r--test/util/multiprocess_util.cc3
-rw-r--r--test/util/multiprocess_util.h3
-rw-r--r--test/util/posix_error.cc2
-rw-r--r--test/util/posix_error.h14
-rw-r--r--tools/bazel.mk2
-rw-r--r--website/cmd/syscalldocs/main.go34
306 files changed, 12496 insertions, 5601 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index cddf5504b..cb272aef6 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -45,6 +45,14 @@ steps:
- make BAZEL_OPTIONS=--config=cross-aarch64 artifacts/aarch64
- make release
+ # Images tests.
+ - <<: *common
+ label: ":docker: Images (x86_64)"
+ command: make ARCH=x86_64 load-all-images
+ - <<: *common
+ label: ":docker: Images (aarch64)"
+ command: make ARCH=aarch64 load-all-images
+
# Basic unit tests.
- <<: *common
label: ":test_tube: Unit tests"
@@ -192,7 +200,7 @@ steps:
command: make benchmark-platforms BENCHMARKS_SUITE=node BENCHMARKS_TARGETS=test/benchmarks/network:node_test
- <<: *benchmarks
label: ":redis: Redis benchmarks"
- command: make benchmark-platforms BENCHMARKS_SUITE=redis BENCHMARKS_TARGETS=test/benchmarks/database:redis_test
+ command: make benchmark-platforms BENCHMARKS_SUITE=redis BENCHMARKS_TARGETS=test/benchmarks/database:redis_test BENCHMARKS_OPTIONS=-test.benchtime=15s
- <<: *benchmarks
label: ":ruby: Ruby benchmarks"
command: make benchmark-platforms BENCHMARKS_SUITE=ruby BENCHMARKS_TARGETS=test/benchmarks/network:ruby_test
diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md
index 69033357c..8e5721ad1 100644
--- a/g3doc/user_guide/FAQ.md
+++ b/g3doc/user_guide/FAQ.md
@@ -137,9 +137,16 @@ sandbox isolation. There are a few different workarounds you can try:
* Use IPs instead of container names.
* Use [Kubernetes][k8s]. Container name lookup works fine in Kubernetes.
+### I'm getting an error like `dial unix /run/containerd/s/09e4...8cff: connect: connection refused: unknown` {#shim-connect}
+
+This error may happen when using `gvisor-containerd-shim` with a `containerd`
+that does not contain the fix for [CVE-2020-15257]. The resolve the issue,
+update containerd to 1.3.9 or 1.4.3 (or newer versions respectively).
+
[security-model]: /docs/architecture_guide/security/
[host-net]: /docs/user_guide/networking/#network-passthrough
[debugging]: /docs/user_guide/debugging/
[filesystem]: /docs/user_guide/filesystem/
[docker]: /docs/user_guide/quick_start/docker/
[k8s]: /docs/user_guide/quick_start/kubernetes/
+[CVE-2020-15257]: https://github.com/containerd/containerd/security/advisories/GHSA-36xw-fx78-c5r4
diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md
index 4f5e721be..011af3b10 100644
--- a/g3doc/user_guide/containerd/configuration.md
+++ b/g3doc/user_guide/containerd/configuration.md
@@ -1,8 +1,8 @@
# Containerd Advanced Configuration
This document describes how to configure runtime options for
-`containerd-shim-runsc-v1`. This follows the
-[Containerd Quick Start](./quick_start.md) and requires containerd 1.2 or later.
+`containerd-shim-runsc-v1`. You can find the installation instructions and
+minimal requirements in [Containerd Quick Start](./quick_start.md).
## Shim Configuration
@@ -47,27 +47,6 @@ When you are done, restart containerd to pick up the changes.
sudo systemctl restart containerd
```
-### Containerd 1.2
-
-For containerd 1.2, the config file is not configurable. It should be named
-`config.toml` and located in the runtime root. By default, this is
-`/run/containerd/runsc`.
-
-### Example: Enable the KVM platform
-
-gVisor enables the use of a number of platforms. This example shows how to
-configure `containerd-shim-runsc-v1` to use gvisor with the KVM platform.
-
-Find out more about platform in the
-[Platforms Guide](../../architecture_guide/platforms.md).
-
-```shell
-cat <<EOF | sudo tee /etc/containerd/runsc.toml
-[runsc_config]
- platform = "kvm"
-EOF
-```
-
## Debug
When `shim_debug` is enabled in `/etc/containerd/config.toml`, containerd will
diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md
index a98fe5c4a..132d80927 100644
--- a/g3doc/user_guide/containerd/quick_start.md
+++ b/g3doc/user_guide/containerd/quick_start.md
@@ -1,7 +1,7 @@
# Containerd Quick Start
This document describes how to use `containerd-shim-runsc-v1` with the
-containerd runtime handler support on `containerd` 1.2 or later.
+containerd runtime handler support on `containerd`.
> ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you
> may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details.
@@ -11,7 +11,8 @@ containerd runtime handler support on `containerd` 1.2 or later.
- **runsc** and **containerd-shim-runsc-v1**: See the
[installation guide](/docs/user_guide/install/).
- **containerd**: See the [containerd website](https://containerd.io/) for
- information on how to install containerd.
+ information on how to install containerd. **Minimal version supported: 1.3.9
+ or 1.4.3.**
## Configure containerd
diff --git a/images/basic/fsstress/Dockerfile b/images/basic/fsstress/Dockerfile.x86_64
index 84604ead5..21b86065a 100644
--- a/images/basic/fsstress/Dockerfile
+++ b/images/basic/fsstress/Dockerfile.x86_64
@@ -1,7 +1,7 @@
# Usage: docker run --rm fsstress -d /test -n 10000 -p 100 -X -v
FROM alpine
-RUN apk add git
+RUN apk update && apk add git
RUN git clone https://github.com/linux-test-project/ltp.git --depth 1
WORKDIR /ltp
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index d1ca56370..b84d7c048 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -21,6 +21,7 @@ const (
F_SETFD = 2
F_GETFL = 3
F_SETFL = 4
+ F_GETLK = 5
F_SETLK = 6
F_SETLKW = 7
F_SETOWN = 8
@@ -55,7 +56,7 @@ type Flock struct {
_ [4]byte
Start int64
Len int64
- Pid int32
+ PID int32
_ [4]byte
}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 8591acbf2..cb33c37bd 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -416,6 +416,18 @@ type TCPInfo struct {
RwndLimited uint64
// SndBufLimited is the time in microseconds limited by send buffer.
SndBufLimited uint64
+
+ Delievered uint32
+ DelieveredCe uint32
+
+ // BytesSent is RFC4898 tcpEStatsPerfHCDataOctetsOut.
+ BytesSent uint64
+ // BytesRetrans is RFC4898 tcpEStatsPerfOctetsRetrans.
+ BytesRetrans uint64
+ // DSACKDups is RFC4898 tcpEStatsStackDSACKDups.
+ DSACKDups uint32
+ // ReordSeen is the number of reordering events seen.
+ ReordSeen uint32
}
// SizeOfTCPInfo is the binary size of a TCPInfo struct.
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
index 2a8d4708b..1a3c0916f 100644
--- a/pkg/abi/linux/tcp.go
+++ b/pkg/abi/linux/tcp.go
@@ -59,3 +59,12 @@ const (
MAX_TCP_KEEPINTVL = 32767
MAX_TCP_KEEPCNT = 127
)
+
+// Congestion control states from include/uapi/linux/tcp.h.
+const (
+ TCP_CA_Open = 0
+ TCP_CA_Disorder = 1
+ TCP_CA_CWR = 2
+ TCP_CA_Recovery = 3
+ TCP_CA_Loss = 4
+)
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index fdfe31417..6f3d72e83 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -26,7 +26,6 @@ import (
"fmt"
"io"
"sort"
- "sync/atomic"
"testing"
"gvisor.dev/gvisor/pkg/sync"
@@ -69,12 +68,18 @@ var globalData struct {
}
// ClearCoverageData clears existing coverage data.
+//
+//go:norace
func ClearCoverageData() {
coverageMu.Lock()
defer coverageMu.Unlock()
+
+ // We do not use atomic operations while reading/writing to the counters,
+ // which would drastically degrade performance. Slight discrepancies due to
+ // racing is okay for the purposes of kcov.
for _, counters := range coverdata.Cover.Counters {
for index := 0; index < len(counters); index++ {
- atomic.StoreUint32(&counters[index], 0)
+ counters[index] = 0
}
}
}
@@ -114,6 +119,8 @@ var coveragePool = sync.Pool{
// ensure that each event is only reported once. Due to the limitations of Go
// coverage tools, we reset the global coverage data every time this function is
// run.
+//
+//go:norace
func ConsumeCoverageData(w io.Writer) int {
InitCoverageData()
@@ -125,11 +132,14 @@ func ConsumeCoverageData(w io.Writer) int {
for fileNum, file := range globalData.files {
counters := coverdata.Cover.Counters[file]
for index := 0; index < len(counters); index++ {
- if atomic.LoadUint32(&counters[index]) == 0 {
+ // We do not use atomic operations while reading/writing to the counters,
+ // which would drastically degrade performance. Slight discrepancies due to
+ // racing is okay for the purposes of kcov.
+ if counters[index] == 0 {
continue
}
// Non-zero coverage data found; consume it and report as a PC.
- atomic.StoreUint32(&counters[index], 0)
+ counters[index] = 0
pc := globalData.syntheticPCs[fileNum][index]
usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
n, err := w.Write(pcBuffer[:])
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
index af8230a7d..387f713aa 100644
--- a/pkg/sentry/fs/fdpipe/pipe_state.go
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -34,7 +34,9 @@ func (p *pipeOperations) beforeSave() {
} else if p.flags.Write {
file, err := p.opener.NonBlockingOpen(context.Background(), fs.PermMask{Write: true})
if err != nil {
- panic(fs.ErrSaveRejection{fmt.Errorf("write-only pipe end cannot be re-opened as %v: %v", p, err)})
+ panic(&fs.ErrSaveRejection{
+ Err: fmt.Errorf("write-only pipe end cannot be re-opened as %#v: %w", p, err),
+ })
}
file.Close()
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index a020da53b..44587bb37 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -144,7 +144,7 @@ type ErrSaveRejection struct {
}
// Error returns a sensible description of the save rejection error.
-func (e ErrSaveRejection) Error() string {
+func (e *ErrSaveRejection) Error() string {
return "save rejected due to unsupported file system state: " + e.Err.Error()
}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index a3402e343..141e3c27f 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -67,7 +67,9 @@ func (i *inodeFileState) beforeSave() {
if i.sattr.Type == fs.RegularFile {
uattr, err := i.unstableAttr(&dummyClockContext{context.Background()})
if err != nil {
- panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.s.inodeMappings[i.sattr.InodeID], err)})
+ panic(&fs.ErrSaveRejection{
+ Err: fmt.Errorf("failed to get unstable atttribute of %s: %w", i.s.inodeMappings[i.sattr.InodeID], err),
+ })
}
i.savedUAttr = &uattr
}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index a2f3d5918..07b4fb70f 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -257,7 +257,7 @@ func (c *ConnectedEndpoint) Passcred() bool {
}
// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
-func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index ae3331737..4d3b216d8 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -41,6 +41,8 @@ go_library(
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/sync",
"//pkg/waiter",
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 8a5d9c7eb..57686ce07 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -54,6 +54,8 @@ import (
"math"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -83,6 +85,17 @@ const (
// offset 0 to LockEOF.
const LockEOF = math.MaxUint64
+// OwnerInfo describes the owner of a lock.
+//
+// TODO(gvisor.dev/issue/5264): We may need to add other fields in the future
+// (e.g., Linux's file_lock.fl_flags to support open file-descriptor locks).
+//
+// +stateify savable
+type OwnerInfo struct {
+ // PID is the process ID of the lock owner.
+ PID int32
+}
+
// Lock is a regional file lock. It consists of either a single writer
// or a set of readers.
//
@@ -92,14 +105,20 @@ const LockEOF = math.MaxUint64
// A Lock may be downgraded from a write lock to a read lock only if
// the write lock's uid is the same as the read lock.
//
+// Accesses to Lock are synchronized through the Locks object to which it
+// belongs.
+//
// +stateify savable
type Lock struct {
// Readers are the set of read lock holders identified by UniqueID.
- // If len(Readers) > 0 then HasWriter must be false.
- Readers map[UniqueID]bool
+ // If len(Readers) > 0 then Writer must be nil.
+ Readers map[UniqueID]OwnerInfo
// Writer holds the writer unique ID. It's nil if there are no writers.
Writer UniqueID
+
+ // WriterInfo describes the writer. It is only meaningful if Writer != nil.
+ WriterInfo OwnerInfo
}
// Locks is a thread-safe wrapper around a LockSet.
@@ -135,14 +154,14 @@ const (
// acquiring the lock in a non-blocking mode or "interrupted" if in a blocking mode.
// Blocker is the interface used to provide blocking behavior, passing a nil Blocker
// will result in non-blocking behavior.
-func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) bool {
+func (l *Locks) LockRegion(uid UniqueID, ownerPID int32, t LockType, r LockRange, block Blocker) bool {
for {
l.mu.Lock()
// Blocking locks must run in a loop because we'll be woken up whenever an unlock event
// happens for this lock. We will then attempt to take the lock again and if it fails
// continue blocking.
- res := l.locks.lock(uid, t, r)
+ res := l.locks.lock(uid, ownerPID, t, r)
if !res && block != nil {
e, ch := waiter.NewChannelEntry(nil)
l.blockedQueue.EventRegister(&e, EventMaskAll)
@@ -161,6 +180,14 @@ func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker)
}
}
+// LockRegionVFS1 is a wrapper around LockRegion for VFS1, which does not implement
+// F_GETLK (and does not care about storing PIDs as a result).
+//
+// TODO(gvisor.dev/issue/1624): Delete.
+func (l *Locks) LockRegionVFS1(uid UniqueID, t LockType, r LockRange, block Blocker) bool {
+ return l.LockRegion(uid, 0 /* ownerPID */, t, r, block)
+}
+
// UnlockRegion attempts to release a lock for the uid on a region of a file.
// This operation is always successful, even if there did not exist a lock on
// the requested region held by uid in the first place.
@@ -175,13 +202,14 @@ func (l *Locks) UnlockRegion(uid UniqueID, r LockRange) {
// makeLock returns a new typed Lock that has either uid as its only reader
// or uid as its only writer.
-func makeLock(uid UniqueID, t LockType) Lock {
- value := Lock{Readers: make(map[UniqueID]bool)}
+func makeLock(uid UniqueID, ownerPID int32, t LockType) Lock {
+ value := Lock{Readers: make(map[UniqueID]OwnerInfo)}
switch t {
case ReadLock:
- value.Readers[uid] = true
+ value.Readers[uid] = OwnerInfo{PID: ownerPID}
case WriteLock:
value.Writer = uid
+ value.WriterInfo = OwnerInfo{PID: ownerPID}
default:
panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
}
@@ -190,17 +218,20 @@ func makeLock(uid UniqueID, t LockType) Lock {
// isHeld returns true if uid is a holder of Lock.
func (l Lock) isHeld(uid UniqueID) bool {
- return l.Writer == uid || l.Readers[uid]
+ if _, ok := l.Readers[uid]; ok {
+ return true
+ }
+ return l.Writer == uid
}
// lock sets uid as a holder of a typed lock on Lock.
//
// Preconditions: canLock is true for the range containing this Lock.
-func (l *Lock) lock(uid UniqueID, t LockType) {
+func (l *Lock) lock(uid UniqueID, ownerPID int32, t LockType) {
switch t {
case ReadLock:
// If we are already a reader, then this is a no-op.
- if l.Readers[uid] {
+ if _, ok := l.Readers[uid]; ok {
return
}
// We cannot downgrade a write lock to a read lock unless the
@@ -210,11 +241,11 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
}
// Ensure that there is only one reader if upgrading.
- l.Readers = make(map[UniqueID]bool)
+ l.Readers = make(map[UniqueID]OwnerInfo)
// Ensure that there is no longer a writer.
l.Writer = nil
}
- l.Readers[uid] = true
+ l.Readers[uid] = OwnerInfo{PID: ownerPID}
return
case WriteLock:
// If we are already the writer, then this is a no-op.
@@ -228,13 +259,14 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
if readers != 1 {
panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, too many readers %v", uid, l.Readers))
}
- if !l.Readers[uid] {
+ if _, ok := l.Readers[uid]; !ok {
panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, conflicting reader %v", uid, l.Readers))
}
}
// Ensure that there is only a writer.
- l.Readers = make(map[UniqueID]bool)
+ l.Readers = make(map[UniqueID]OwnerInfo)
l.Writer = uid
+ l.WriterInfo = OwnerInfo{PID: ownerPID}
default:
panic(fmt.Sprintf("lock: invalid lock type %d", t))
}
@@ -247,7 +279,7 @@ func (l LockSet) lockable(r LockRange, check func(value Lock) bool) bool {
// Get our starting point.
seg := l.LowerBoundSegment(r.Start)
for seg.Ok() && seg.Start() < r.End {
- // Note that we don't care about overruning the end of the
+ // Note that we don't care about overrunning the end of the
// last segment because if everything checks out we'll just
// split the last segment.
if !check(seg.Value()) {
@@ -281,7 +313,7 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
if value.Writer == nil {
// Then this uid can only take a write lock if this is a private
// upgrade, meaning that the only reader is uid.
- return len(value.Readers) == 1 && value.Readers[uid]
+ return value.isOnlyReader(uid)
}
// If the uid is already a writer on this region, then
// adding a write lock would be a no-op.
@@ -292,11 +324,19 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
}
}
+func (l *Lock) isOnlyReader(uid UniqueID) bool {
+ if len(l.Readers) != 1 {
+ return false
+ }
+ _, ok := l.Readers[uid]
+ return ok
+}
+
// lock returns true if uid took a lock of type t on the entire range of
// LockRange.
//
// Preconditions: r.Start <= r.End (will panic otherwise).
-func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
+func (l *LockSet) lock(uid UniqueID, ownerPID int32, t LockType, r LockRange) bool {
if r.Start > r.End {
panic(fmt.Sprintf("lock: r.Start %d > r.End %d", r.Start, r.End))
}
@@ -317,7 +357,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
seg, gap := l.Find(r.Start)
if gap.Ok() {
// Fill in the gap and get the next segment to modify.
- seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, t)).NextSegment()
+ seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, ownerPID, t)).NextSegment()
} else if seg.Start() < r.Start {
// Get our first segment to modify.
_, seg = l.Split(seg, r.Start)
@@ -331,12 +371,12 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
// Set the lock on the segment. This is guaranteed to
// always be safe, given canLock above.
value := seg.ValuePtr()
- value.lock(uid, t)
+ value.lock(uid, ownerPID, t)
// Fill subsequent gaps.
gap = seg.NextGap()
if gr := gap.Range().Intersect(r); gr.Length() > 0 {
- seg = l.Insert(gap, gr, makeLock(uid, t)).NextSegment()
+ seg = l.Insert(gap, gr, makeLock(uid, ownerPID, t)).NextSegment()
} else {
seg = gap.NextSegment()
}
@@ -380,7 +420,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
// only ever be one writer and no readers, then this
// lock should always be removed from the set.
remove = true
- } else if value.Readers[uid] {
+ } else if _, ok := value.Readers[uid]; ok {
// If uid is the last reader, then just remove the entire
// segment.
if len(value.Readers) == 1 {
@@ -390,7 +430,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
// affecting any other segment's readers. To do
// this, we need to make a copy of the Readers map
// and not add this uid.
- newValue := Lock{Readers: make(map[UniqueID]bool)}
+ newValue := Lock{Readers: make(map[UniqueID]OwnerInfo)}
for k, v := range value.Readers {
if k != uid {
newValue.Readers[k] = v
@@ -451,3 +491,72 @@ func ComputeRange(start, length, offset int64) (LockRange, error) {
// Offset is guaranteed to be positive at this point.
return LockRange{Start: uint64(offset), End: end}, nil
}
+
+// TestRegion checks whether the lock holder identified by uid can hold a lock
+// of type t on range r. It returns a Flock struct representing this
+// information as the F_GETLK fcntl does.
+//
+// Note that the PID returned in the flock structure is relative to the root PID
+// namespace. It needs to be converted to the caller's PID namespace before
+// returning to userspace.
+//
+// TODO(gvisor.dev/issue/5264): we don't support OFD locks through fcntl, which
+// would return a struct with pid = -1.
+func (l *Locks) TestRegion(ctx context.Context, uid UniqueID, t LockType, r LockRange) linux.Flock {
+ f := linux.Flock{Type: linux.F_UNLCK}
+ switch t {
+ case ReadLock:
+ l.testRegion(r, func(lock Lock, start, length uint64) bool {
+ if lock.Writer == nil || lock.Writer == uid {
+ return true
+ }
+ f.Type = linux.F_WRLCK
+ f.PID = lock.WriterInfo.PID
+ f.Start = int64(start)
+ f.Len = int64(length)
+ return false
+ })
+ case WriteLock:
+ l.testRegion(r, func(lock Lock, start, length uint64) bool {
+ if lock.Writer == nil {
+ for k, v := range lock.Readers {
+ if k != uid {
+ // Stop at the first conflict detected.
+ f.Type = linux.F_RDLCK
+ f.PID = v.PID
+ f.Start = int64(start)
+ f.Len = int64(length)
+ return false
+ }
+ }
+ return true
+ }
+ if lock.Writer == uid {
+ return true
+ }
+ f.Type = linux.F_WRLCK
+ f.PID = lock.WriterInfo.PID
+ f.Start = int64(start)
+ f.Len = int64(length)
+ return false
+ })
+ default:
+ panic(fmt.Sprintf("TestRegion: invalid lock type %d", t))
+ }
+ return f
+}
+
+func (l *Locks) testRegion(r LockRange, check func(lock Lock, start, length uint64) bool) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ seg := l.locks.LowerBoundSegment(r.Start)
+ for seg.Ok() && seg.Start() < r.End {
+ lock := seg.Value()
+ if !check(lock, seg.Start(), seg.End()-seg.Start()) {
+ // Stop at the first conflict detected.
+ return
+ }
+ seg = seg.NextSegment()
+ }
+}
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
index 50a16e662..dcc17c0dc 100644
--- a/pkg/sentry/fs/lock/lock_set_functions.go
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -40,7 +40,7 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
return Lock{}, false
}
for k := range val1.Readers {
- if !val2.Readers[k] {
+ if _, ok := val2.Readers[k]; !ok {
return Lock{}, false
}
}
@@ -53,11 +53,12 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) {
// Copy the segment so that split segments don't contain map references
// to other segments.
- val0 := Lock{Readers: make(map[UniqueID]bool)}
+ val0 := Lock{Readers: make(map[UniqueID]OwnerInfo)}
for k, v := range val.Readers {
val0.Readers[k] = v
}
val0.Writer = val.Writer
+ val0.WriterInfo = val.WriterInfo
return val, val0
}
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
index fad90984b..9878c04e1 100644
--- a/pkg/sentry/fs/lock/lock_test.go
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -30,12 +30,12 @@ func equals(e0, e1 []entry) bool {
}
for i := range e0 {
for k := range e0[i].Lock.Readers {
- if !e1[i].Lock.Readers[k] {
+ if _, ok := e1[i].Lock.Readers[k]; !ok {
return false
}
}
for k := range e1[i].Lock.Readers {
- if !e0[i].Lock.Readers[k] {
+ if _, ok := e0[i].Lock.Readers[k]; !ok {
return false
}
}
@@ -90,15 +90,15 @@ func TestCanLock(t *testing.T) {
// 0 1024 2048 3072 4096
l := fill([]entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}},
LockRange: LockRange{1024, 2048},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 3: OwnerInfo{}}},
LockRange: LockRange{2048, 3072},
},
{
@@ -220,7 +220,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -266,7 +266,7 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, 4096},
},
{
@@ -283,7 +283,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -302,7 +302,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -333,7 +333,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -351,7 +351,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -366,11 +366,11 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -383,7 +383,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -398,15 +398,15 @@ func TestSetLock(t *testing.T) {
// 0 4096 8192 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{4096, 8192},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{8192, LockEOF},
},
},
@@ -419,7 +419,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -434,7 +434,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -447,7 +447,7 @@ func TestSetLock(t *testing.T) {
// 0 4096
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, 4096},
},
},
@@ -467,11 +467,11 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -484,7 +484,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -499,15 +499,15 @@ func TestSetLock(t *testing.T) {
// 0 1024 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{1024, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -520,15 +520,15 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, 2048},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -543,7 +543,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
@@ -551,7 +551,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{1024, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -564,15 +564,15 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, 2048},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -587,7 +587,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 3072 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
@@ -595,7 +595,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{1024, 3072},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -608,11 +608,11 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, 2048},
},
},
@@ -634,15 +634,15 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{1024, 2048},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{2048, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -676,7 +676,7 @@ func TestSetLock(t *testing.T) {
l := fill(test.before)
r := LockRange{Start: test.start, End: test.end}
- success := l.lock(test.uid, test.lockType, r)
+ success := l.lock(test.uid, 0 /* ownerPID */, test.lockType, r)
var got []entry
for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
got = append(got, entry{
@@ -739,7 +739,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -752,7 +752,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -765,7 +765,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -797,7 +797,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -810,7 +810,7 @@ func TestUnlock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -849,7 +849,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -862,7 +862,7 @@ func TestUnlock(t *testing.T) {
// 0 4096
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}},
LockRange: LockRange{0, 4096},
},
},
@@ -901,7 +901,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
@@ -909,7 +909,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{1024, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -922,11 +922,11 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -939,7 +939,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{0, LockEOF},
},
},
@@ -952,15 +952,15 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
after: []entry{
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{1024, 4096},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -977,7 +977,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -994,7 +994,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 8},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -1011,7 +1011,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -1028,11 +1028,11 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}},
LockRange: LockRange{4096, 8192},
},
{
- Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}},
+ Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}},
LockRange: LockRange{8192, LockEOF},
},
},
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index e91fa26a4..b44117f40 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -194,16 +193,6 @@ func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions
return mfd.inode.Stat(ctx, fs, opts)
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence)
-}
-
// maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid.
func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
switch cmd {
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
index 70c68cf0a..a0c5b5af5 100644
--- a/pkg/sentry/fsimpl/devpts/replica.go
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -189,13 +188,3 @@ func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOption
fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem()
return rfd.inode.Stat(ctx, fs, opts)
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 0ad79b381..512b70ede 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -311,13 +310,3 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
fd.off = offset
return offset, nil
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index 4a5539b37..5ad9befcd 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -154,13 +153,3 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 3b5927702..9da01cba3 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -90,6 +90,7 @@ type createSyntheticOpts struct {
// * d.isDir().
// * d does not already contain a child with the given name.
func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
+ now := d.fs.clock.Now().Nanoseconds()
child := &dentry{
refs: 1, // held by d
fs: d.fs,
@@ -98,6 +99,10 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
blockSize: usermem.PageSize, // arbitrary
+ atime: now,
+ mtime: now,
+ ctime: now,
+ btime: now,
readFD: -1,
writeFD: -1,
mmapFD: -1,
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 91d5dc174..8f95473b6 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -36,16 +36,26 @@ import (
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
// Snapshot current syncable dentries and special file FDs.
+ fs.renameMu.RLock()
fs.syncMu.Lock()
ds := make([]*dentry, 0, len(fs.syncableDentries))
for d := range fs.syncableDentries {
+ // It's safe to use IncRef here even though fs.syncableDentries doesn't
+ // hold references since we hold fs.renameMu. Note that we can't use
+ // TryIncRef since cached dentries at zero references should still be
+ // synced.
d.IncRef()
ds = append(ds, d)
}
+ fs.renameMu.RUnlock()
sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs))
for sffd := range fs.specialFileFDs {
- sffd.vfsfd.IncRef()
- sffds = append(sffds, sffd)
+ // As above, fs.specialFileFDs doesn't hold references. However, unlike
+ // dentries, an FD that has reached zero references can't be
+ // resurrected, so we can use TryIncRef.
+ if sffd.vfsfd.TryIncRef() {
+ sffds = append(sffds, sffd)
+ }
}
fs.syncMu.Unlock()
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 3cdb1e659..98f7bc52f 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1944,22 +1944,22 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
}
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
-func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
fd.lockLogging.Do(func() {
log.Infof("File lock using gofer file handled internally.")
})
- return fd.LockFD.LockBSD(ctx, uid, t, block)
+ return fd.LockFD.LockBSD(ctx, uid, ownerPID, t, block)
}
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
fd.lockLogging.Do(func() {
log.Infof("Range lock using gofer file handled internally.")
})
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+ return fd.Locks().LockPOSIX(ctx, uid, ownerPID, t, r, block)
}
// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
+ return fd.Locks().UnlockPOSIX(ctx, uid, r)
}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 36a3f6810..05f11fbd5 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -810,13 +809,3 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) {
func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask)
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 60acc367f..72aa535f8 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -201,7 +201,7 @@ func (c *ConnectedEndpoint) Passcred() bool {
}
// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
-func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index f5c596fec..0f9e20a84 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -370,13 +369,3 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
return syserror.ERESTARTSYS
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 485504995..65054b0ea 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -136,13 +135,3 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
// DynamicBytesFiles are immutable.
return syserror.EPERM
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index f8dae22f8..e55111af0 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -275,13 +274,3 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio
func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length)
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index eac578f25..8139bff76 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -371,6 +371,8 @@ type OrderedChildrenOptions struct {
// OrderedChildren may modify the tracked children. This applies to
// operations related to rename, unlink and rmdir. If an OrderedChildren is
// not writable, these operations all fail with EPERM.
+ //
+ // Note that writable users must implement the sticky bit (I_SVTX).
Writable bool
}
@@ -556,7 +558,6 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child Inode)
return err
}
- // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
o.removeLocked(name)
return nil
}
@@ -603,8 +604,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c
if err := o.checkExistingLocked(oldname, child); err != nil {
return err
}
+ o.removeLocked(oldname)
- // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
dst.replaceChildLocked(ctx, newname, child)
return nil
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 3492409b2..082fa6504 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -42,7 +42,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -820,13 +819,3 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 75be6129f..fdae163d1 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -518,16 +517,6 @@ func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error {
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *memFD) Release(context.Context) {}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
-
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -1110,13 +1099,3 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
func (fd *namespaceFD) Release(ctx context.Context) {
fd.inode.DecRef(ctx)
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 7ee6227a9..d6f076cd6 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -393,7 +393,7 @@ func TestProcSelf(t *testing.T) {
t.Fatalf("CreateTask(): %v", err)
}
- collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{
+ collector := s.WithTemporaryContext(task.AsyncContext()).ListDirents(&vfs.PathOperation{
Root: s.Root,
Start: s.Root,
Path: fspath.Parse("/proc/self/"),
@@ -491,11 +491,11 @@ func TestTree(t *testing.T) {
t.Fatalf("CreateTask(): %v", err)
}
// Add file to populate /proc/[pid]/fd and fdinfo directories.
- task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{})
+ task.FDTable().NewFDVFS2(task.AsyncContext(), 0, file, kernel.FDFlags{})
tasks = append(tasks, task)
}
- ctx := tasks[0]
+ ctx := tasks[0].AsyncContext()
fd, err := s.VFS.OpenAt(
ctx,
auth.CredentialsFromContext(s.Ctx),
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 146c7fdfe..4393cc13b 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -140,35 +140,35 @@ func TestLocks(t *testing.T) {
uid1 := 123
uid2 := 456
- if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil {
+ if err := fd.Impl().LockBSD(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil {
+ if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
}
if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err)
}
- if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil {
+ if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 1, End: 2}, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), syserror.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
}
- if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil {
+ if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil {
t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err)
}
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 0c9c639d3..b32c54e20 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -36,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -797,16 +796,6 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
return nil
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
-
// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all
// filesystem state is in-memory.
func (*fileDescription) Sync(context.Context) error {
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index a5171b5ad..8645078a0 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -660,7 +660,6 @@ func (d *dentry) readlink(ctx context.Context) (string, error) {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
- vfs.LockFD
// d is the corresponding dentry to the fileDescription.
d *dentry
@@ -1104,14 +1103,29 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op
return 0, syserror.EROFS
}
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
+ return fd.lowerFD.LockBSD(ctx, ownerPID, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *fileDescription) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return fd.lowerFD.UnlockBSD(ctx)
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block)
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
+ return fd.lowerFD.LockPOSIX(ctx, uid, ownerPID, t, r, block)
}
// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
+ return fd.lowerFD.UnlockPOSIX(ctx, uid, r)
+}
+
+// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX.
+func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
+ return fd.lowerFD.TestPOSIX(ctx, uid, t, r)
}
// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 30d8b4355..798d6a9bd 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -66,7 +66,7 @@ func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry {
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
-func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, context.Context, error) {
t.Helper()
k, err := testutil.Boot()
if err != nil {
@@ -119,7 +119,7 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
root.DecRef(ctx)
mntns.DecRef(ctx)
})
- return vfsObj, root, task, nil
+ return vfsObj, root, task.AsyncContext(), nil
}
// openVerityAt opens a verity file.
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 0ee60569c..8a5b11d40 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -240,7 +240,6 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
- "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/fsimpl/timerfd",
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 7aba31587..a6afabb1c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -153,20 +153,12 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) {
// dropVFS2 drops the table reference.
func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) {
- // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
- // entire file.
- err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET)
+ // Release any POSIX lock possibly held by the FDTable.
+ err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF})
if err != nil && err != syserror.ENOLCK {
panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
}
- // Generate inotify events.
- ev := uint32(linux.IN_CLOSE_NOWRITE)
- if file.IsWritable() {
- ev = linux.IN_CLOSE_WRITE
- }
- file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent)
-
// Drop the table's reference.
file.DecRef(ctx)
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 303ae8056..ef4e934a1 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -593,8 +593,8 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
// Wrap this error in ErrSaveRejection so that it will trigger a save
// error, rather than a panic. This also allows us to distinguish Fsync
// errors from state file errors in state.Save.
- return fs.ErrSaveRejection{
- Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err),
+ return &fs.ErrSaveRejection{
+ Err: fmt.Errorf("%q was not sufficiently synced: %w", name, err),
}
}
return nil
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 2c32d017d..71daa9f4b 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -27,7 +27,6 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/fs/lock",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index d5a91730d..3b6336e94 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -441,13 +440,3 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr
}
return n, err
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index 9419f2e95..ecbe8f920 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.
// syserror.ErrInterrupted if t is interrupted.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error {
+func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error {
if !haveDeadline {
return t.block(C, nil)
}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index fdadb52c0..e9da99067 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -216,7 +216,7 @@ func (ns *PIDNamespace) TaskWithID(tid ThreadID) *Task {
return t
}
-// ThreadGroupWithID returns the thread group lead by the task with thread ID
+// ThreadGroupWithID returns the thread group led by the task with thread ID
// tid in PID namespace ns. If no task has that TID, or if the task with that
// TID is not a thread group leader, ThreadGroupWithID returns nil.
func (ns *PIDNamespace) ThreadGroupWithID(tid ThreadID) *ThreadGroup {
@@ -292,6 +292,11 @@ func (ns *PIDNamespace) UserNamespace() *auth.UserNamespace {
return ns.userns
}
+// Root returns the root PID namespace of ns.
+func (ns *PIDNamespace) Root() *PIDNamespace {
+ return ns.owner.Root
+}
+
// A threadGroupNode defines the relationship between a thread group and the
// rest of the system. Conceptually, threadGroupNode is data belonging to the
// owning TaskSet, as if TaskSet contained a field `nodes
@@ -485,3 +490,8 @@ func (t *Task) Parent() *Task {
func (t *Task) ThreadID() ThreadID {
return t.tg.pidns.IDOfTask(t)
}
+
+// TGIDInRoot returns t's TGID in the root PID namespace.
+func (t *Task) TGIDInRoot() ThreadID {
+ return t.tg.pidns.owner.Root.IDOfThreadGroup(t.tg)
+}
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 6efe5102b..73bfbea49 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -17,7 +17,6 @@ package mm
import (
"bytes"
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
@@ -165,12 +164,12 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
}
if s != "" {
// Per linux, we pad until the 74th character.
- if pad := 73 - lineLen; pad > 0 {
- b.WriteString(strings.Repeat(" ", pad))
+ for pad := 73 - lineLen; pad > 0; pad-- {
+ b.WriteByte(' ')
}
b.WriteString(s)
}
- b.WriteString("\n")
+ b.WriteByte('\n')
}
// ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 675efdc7c..69e37330b 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -1055,18 +1055,11 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
mm.activeMu.Lock()
defer mm.activeMu.Unlock()
- // Linux's mm/madvise.c:madvise_dontneed() => mm/memory.c:zap_page_range()
- // is analogous to our mm.invalidateLocked(ar, true, true). We inline this
- // here, with the special case that we synchronously decommit
- // uniquely-owned (non-copy-on-write) pages for private anonymous vma,
- // which is the common case for MADV_DONTNEED. Invalidating these pmas, and
- // allowing them to be reallocated when touched again, increases pma
- // fragmentation, which may significantly reduce performance for
- // non-vectored I/O implementations. Also, decommitting synchronously
- // ensures that Decommit immediately reduces host memory usage.
+ // This is invalidateLocked(invalidatePrivate=true, invalidateShared=true),
+ // with the additional wrinkle that we must refuse to invalidate pmas under
+ // mlocked vmas.
var didUnmapAS bool
pseg := mm.pmas.LowerBoundSegment(ar.Start)
- mf := mm.mfp.MemoryFile()
for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
vma := vseg.ValuePtr()
if vma.mlockMode != memmap.MLockNone {
@@ -1081,20 +1074,8 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
}
}
for pseg.Ok() && pseg.Start() < vsegAR.End {
- pma := pseg.ValuePtr()
- if pma.private && !mm.isPMACopyOnWriteLocked(vseg, pseg) {
- psegAR := pseg.Range().Intersect(ar)
- if vsegAR.IsSupersetOf(psegAR) && vma.mappable == nil {
- if err := mf.Decommit(pseg.fileRangeOf(psegAR)); err == nil {
- pseg = pseg.NextSegment()
- continue
- }
- // If an error occurs, fall through to the general
- // invalidation case below.
- }
- }
pseg = mm.pmas.Isolate(pseg, vsegAR)
- pma = pseg.ValuePtr()
+ pma := pseg.ValuePtr()
if !didUnmapAS {
// Unmap all of ar, not just pseg.Range(), to minimize host
// syscalls. AddressSpace mappings must be removed before
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index aacd7ce70..17fb0a0d8 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -550,6 +550,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// Wait for the syscall-enter stop.
sig := t.wait(stopped)
+ if sig == syscall.SIGSTOP {
+ // SIGSTOP was delivered to another thread in the same thread
+ // group, which initiated another group stop. Just ignore it.
+ continue
+ }
+
// Refresh all registers.
if err := t.getRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
@@ -566,13 +572,11 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// Is it a system call?
if sig == (syscallEvent | syscall.SIGTRAP) {
+ s.arm64SyscallWorkaround(t, regs)
+
// Ensure registers are sane.
updateSyscallRegs(regs)
return true
- } else if sig == syscall.SIGSTOP {
- // SIGSTOP was delivered to another thread in the same thread
- // group, which initiated another group stop. Just ignore it.
- continue
}
// Grab signal information.
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 020bbda79..04815282b 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -257,3 +257,6 @@ func probeSeccomp() bool {
}
}
}
+
+func (s *subprocess) arm64SyscallWorkaround(t *thread, regs *arch.Registers) {
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index bd618fae8..416132967 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -21,6 +21,7 @@ import (
"strings"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -172,3 +173,38 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi
func probeSeccomp() bool {
return true
}
+
+func (s *subprocess) arm64SyscallWorkaround(t *thread, regs *arch.Registers) {
+ // On ARM64, when ptrace stops on a system call, it uses the x7
+ // register to indicate whether the stop has been signalled from
+ // syscall entry or syscall exit. This means that we can't get a value
+ // of this register and we can't change it. More details are in the
+ // comment for tracehook_report_syscall in arch/arm64/kernel/ptrace.c.
+ //
+ // This happens only if we stop on a system call, so let's queue a
+ // signal, resume a stub thread and catch it on a signal handling.
+ t.NotifyInterrupt()
+ for {
+ if _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ unix.PTRACE_SYSEMU,
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
+ }
+
+ // Wait for the syscall-enter stop.
+ sig := t.wait(stopped)
+ if sig == syscall.SIGSTOP {
+ // SIGSTOP was delivered to another thread in the same thread
+ // group, which initiated another group stop. Just ignore it.
+ continue
+ }
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ t.dumpAndPanic(fmt.Sprintf("unexpected syscall event"))
+ }
+ break
+ }
+ if err := t.getRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index b6ebe29d6..a8e6f172b 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -28,7 +28,6 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/hostfd",
"//pkg/sentry/inet",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 5b868216d..17f59ba1f 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -377,10 +377,8 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
- case linux.IP_PKTINFO:
- optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 9a2cac40b..f82c7c224 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -144,16 +143,6 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
return int64(n), err
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
-}
-
type socketProviderVFS2 struct {
family int
}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 70c561cce..2f913787b 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -15,7 +15,6 @@
package netfilter
import (
- "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
}
- n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterface)
- }
- ifname := string(iptip.OutputInterface[:n])
-
- n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterfaceMask)
- }
- ifnameMask := string(iptip.OutputInterfaceMask[:n])
-
return stack.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
// A Protocol value of 0 indicates all protocols match.
@@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
Src: tcpip.Address(iptip.Src[:]),
SrcMask: tcpip.Address(iptip.SrcMask[:]),
SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
- OutputInterface: ifname,
- OutputInterfaceMask: ifnameMask,
+ InputInterface: string(trimNullBytes(iptip.InputInterface[:])),
+ InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])),
+ InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0,
+ OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])),
+ OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])),
OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
}, nil
}
@@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool {
// - Dst and DstMask
// - Src and SrcMask
// - The inverse destination IP check flag
+ // - InputInterface, InputInterfaceMask and its inverse.
// - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInterface = [linux.IFNAMSIZ]byte{}
+ const flagMask = 0
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
- return iptip.InputInterface != emptyInterface ||
- iptip.InputInterfaceMask != emptyInterface ||
- iptip.Flags != 0 ||
+ const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP |
+ linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT
+ return iptip.Flags&^flagMask != 0 ||
iptip.InverseFlags&^inverseMask != 0
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 5dbb604f0..263d9d3b5 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -15,7 +15,6 @@
package netfilter
import (
- "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
}
- n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterface)
- }
- ifname := string(iptip.OutputInterface[:n])
-
- n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterfaceMask)
- }
- ifnameMask := string(iptip.OutputInterfaceMask[:n])
-
return stack.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
// In ip6tables a flag controls whether to check the protocol.
@@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
Src: tcpip.Address(iptip.Src[:]),
SrcMask: tcpip.Address(iptip.SrcMask[:]),
SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0,
- OutputInterface: ifname,
- OutputInterfaceMask: ifnameMask,
+ InputInterface: string(trimNullBytes(iptip.InputInterface[:])),
+ InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])),
+ InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0,
+ OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])),
+ OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])),
OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0,
}, nil
}
@@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool {
// - Dst and DstMask
// - Src and SrcMask
// - The inverse destination IP check flag
+ // - InputInterface, InputInterfaceMask and its inverse.
// - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInterface = [linux.IFNAMSIZ]byte{}
- flagMask := uint8(linux.IP6T_F_PROTO)
+ const flagMask = linux.IP6T_F_PROTO
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT)
- return iptip.InputInterface != emptyInterface ||
- iptip.InputInterfaceMask != emptyInterface ||
- iptip.Flags&^flagMask != 0 ||
+ const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP |
+ linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT
+ return iptip.Flags&^flagMask != 0 ||
iptip.InverseFlags&^inverseMask != 0 ||
iptip.TOS != 0
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 26bd1abd4..7ae18b2a3 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,6 +17,7 @@
package netfilter
import (
+ "bytes"
"errors"
"fmt"
@@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP
rev.Revision = maxSupported
return rev, nil
}
+
+func trimNullBytes(b []byte) []byte {
+ n := bytes.IndexByte(b, 0)
+ if n == -1 {
+ n = len(b)
+ }
+ return b[:n]
+}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 69d13745e..176fa6116 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// Support only for OUTPUT chain.
// TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 352c51390..2740697b3 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index c88d8268d..466d5395d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 1f926aa91..9313e1167 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index 461d524e5..842036764 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -149,13 +148,3 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
return int64(n), err.ToError()
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 22abca120..915134b41 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -28,7 +28,6 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
@@ -42,7 +41,6 @@ go_library(
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 22e128b96..94f03af48 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -19,7 +19,7 @@
// be used to expose certain endpoints to the sentry while leaving others out,
// for example, TCP endpoints and Unix-domain endpoints.
//
-// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside
// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
// this operation.
package netstack
@@ -55,7 +55,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -194,7 +193,6 @@ var Metrics = tcpip.Stats{
RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."),
OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."),
OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."),
- OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."),
OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."),
OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."),
RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."),
@@ -253,11 +251,11 @@ var errStackType = syserr.New("expected but did not receive a netstack.Stack", l
type commonEndpoint interface {
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and
// transport.Endpoint.GetLocalAddress.
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and
// transport.Endpoint.GetRemoteAddress.
- GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetRemoteAddress() (tcpip.FullAddress, tcpip.Error)
// Readiness implements tcpip.Endpoint.Readiness and
// transport.Endpoint.Readiness.
@@ -265,19 +263,19 @@ type commonEndpoint interface {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt and
// transport.Endpoint.SetSockOpt.
- SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
+ SetSockOpt(tcpip.SettableSocketOption) tcpip.Error
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
- GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
+ GetSockOpt(tcpip.GettableSocketOption) tcpip.Error
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -285,7 +283,7 @@ type commonEndpoint interface {
// LastError implements tcpip.Endpoint.LastError and
// transport.Endpoint.LastError.
- LastError() *tcpip.Error
+ LastError() tcpip.Error
// SocketOptions implements tcpip.Endpoint.SocketOptions and
// transport.Endpoint.SocketOptions.
@@ -440,115 +438,58 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
return int64(res.Count), nil
}
-// ioSequencePayload implements tcpip.Payload.
-//
-// t copies user memory bytes on demand based on the requested size.
-type ioSequencePayload struct {
- ctx context.Context
- src usermem.IOSequence
-}
-
-// FullPayload implements tcpip.Payloader.FullPayload
-func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
- return i.Payload(int(i.src.NumBytes()))
-}
-
-// Payload implements tcpip.Payloader.Payload.
-func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
- if max := int(i.src.NumBytes()); size > max {
- size = max
- }
- v := buffer.NewView(size)
- if _, err := i.src.CopyIn(i.ctx, v); err != nil {
- // EOF can be returned only if src is a file and this means it
- // is in a splice syscall and the error has to be ignored.
- if err == io.EOF {
- return v, nil
- }
- return nil, tcpip.ErrBadAddress
- }
- return v, nil
-}
-
-// DropFirst drops the first n bytes from underlying src.
-func (i *ioSequencePayload) DropFirst(n int) {
- i.src = i.src.DropFirst(int(n))
-}
-
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
- if err == tcpip.ErrWouldBlock {
+ r := src.Reader(ctx)
+ n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
return 0, syserror.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
- if int64(n) < src.NumBytes() {
- return int64(n), syserror.ErrWouldBlock
+ if n < src.NumBytes() {
+ return n, syserror.ErrWouldBlock
}
- return int64(n), nil
+ return n, nil
}
-// readerPayload implements tcpip.Payloader.
-//
-// It allocates a view and reads from a reader on-demand, based on available
-// capacity in the endpoint.
-type readerPayload struct {
- ctx context.Context
- r io.Reader
- count int64
+var _ tcpip.Payloader = (*limitedPayloader)(nil)
+
+type limitedPayloader struct {
+ inner io.LimitedReader
err error
}
-// FullPayload implements tcpip.Payloader.FullPayload.
-func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
- return r.Payload(int(r.count))
+func (l *limitedPayloader) Read(p []byte) (int, error) {
+ n, err := l.inner.Read(p)
+ l.err = err
+ return n, err
}
-// Payload implements tcpip.Payloader.Payload.
-func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
- if size > int(r.count) {
- size = int(r.count)
- }
- v := buffer.NewView(size)
- n, err := r.r.Read(v)
- if n > 0 {
- // We ignore the error here. It may re-occur on subsequent
- // reads, but for now we can enqueue some amount of data.
- r.count -= int64(n)
- return v[:n], nil
- }
- if err == syserror.ErrWouldBlock {
- return nil, tcpip.ErrWouldBlock
- } else if err != nil {
- r.err = err // Save for propation.
- return nil, tcpip.ErrBadAddress
- }
-
- // There is no data and no error. Return an error, which will propagate
- // r.err, which will be nil. This is the desired result: (0, nil).
- return nil, tcpip.ErrBadAddress
+func (l *limitedPayloader) Len() int {
+ return int(l.inner.N)
}
// ReadFrom implements fs.FileOperations.ReadFrom.
func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
- f := &readerPayload{ctx: ctx, r: r, count: count}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ f := limitedPayloader{
+ inner: io.LimitedReader{
+ R: r,
+ N: count,
+ },
+ }
+ n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{
// Reads may be destructive but should be very fast,
// so we can't release the lock while copying data.
Atomic: true,
})
- if err == tcpip.ErrWouldBlock {
- return n, syserror.ErrWouldBlock
- } else if err != nil {
- return int64(n), f.err // Propagate error.
+ if _, ok := err.(*tcpip.ErrBadBuffer); ok {
+ return n, f.err
}
-
- return int64(n), nil
+ return n, syserr.TranslateNetstackError(err).ToError()
}
// Readiness returns a mask of ready events for socket s.
@@ -592,7 +533,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
if family == linux.AF_UNSPEC {
err := s.Endpoint.Disconnect()
- if err == tcpip.ErrNotSupported {
+ if _, ok := err.(*tcpip.ErrNotSupported); ok {
return syserr.ErrAddressFamilyNotSupported
}
return syserr.TranslateNetstackError(err)
@@ -614,15 +555,16 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
- if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ switch err := s.Endpoint.Connect(addr); err.(type) {
+ case *tcpip.ErrConnectStarted, *tcpip.ErrAlreadyConnecting:
+ case *tcpip.ErrNoPortAvailable:
if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
// TCP unlike UDP returns EADDRNOTAVAIL when it can't
// find an available local ephemeral port.
- if err == tcpip.ErrNoPortAvailable {
- return syserr.ErrAddressNotAvailable
- }
+ return syserr.ErrAddressNotAvailable
}
-
+ return syserr.TranslateNetstackError(err)
+ default:
return syserr.TranslateNetstackError(err)
}
@@ -680,16 +622,16 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Issue the bind request to the endpoint.
err := s.Endpoint.Bind(addr)
- if err == tcpip.ErrNoPortAvailable {
+ if _, ok := err.(*tcpip.ErrNoPortAvailable); ok {
// Bind always returns EADDRINUSE irrespective of if the specified port was
// already bound or if an ephemeral port was requested but none were
// available.
//
- // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because
+ // *tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because
// UDP connect returns EAGAIN on ephemeral port exhaustion.
//
// TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion.
- err = tcpip.ErrPortInUse
+ err = &tcpip.ErrPortInUse{}
}
return syserr.TranslateNetstackError(err)
@@ -712,7 +654,8 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAdd
// Try to accept the connection again; if it fails, then wait until we
// get a notification.
for {
- if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock {
+ ep, wq, err := s.Endpoint.Accept(peerAddr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
return ep, wq, syserr.TranslateNetstackError(err)
}
@@ -731,7 +674,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
- if terr != tcpip.ErrWouldBlock || !blocking {
+ if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
@@ -912,7 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ size, err := ep.SocketOptions().GetSendBufferSize()
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1164,6 +1107,29 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
// TODO(b/64800844): Translate fields once they are added to
// tcpip.TCPInfoOption.
info := linux.TCPInfo{}
+ switch v.CcState {
+ case tcpip.RTORecovery:
+ info.CaState = linux.TCP_CA_Loss
+ case tcpip.FastRecovery, tcpip.SACKRecovery:
+ info.CaState = linux.TCP_CA_Recovery
+ case tcpip.Disorder:
+ info.CaState = linux.TCP_CA_Disorder
+ case tcpip.Open:
+ info.CaState = linux.TCP_CA_Open
+ }
+ info.RTO = uint32(v.RTO / time.Microsecond)
+ info.RTT = uint32(v.RTT / time.Microsecond)
+ info.RTTVar = uint32(v.RTTVar / time.Microsecond)
+ info.SndSsthresh = v.SndSsthresh
+ info.SndCwnd = v.SndCwnd
+
+ // In netstack reorderSeen is updated only when RACK is enabled.
+ // We only track whether the reordering is seen, which is
+ // different than Linux where reorderSeen is not specific to
+ // RACK and is incremented when a reordering event is seen.
+ if v.ReorderSeen {
+ info.ReordSeen = 1
+ }
// Linux truncates the output binary to outLen.
buf := t.CopyScratchBuffer(info.SizeBytes())
@@ -1681,8 +1647,16 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
+ family, _, _ := s.Type()
+ // TODO(gvisor.dev/issue/5132): We currently do not support
+ // setting this option for unix sockets.
+ if family == linux.AF_UNIX {
+ return nil
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
+ ep.SocketOptions().SetSendBufferSize(int64(v), true)
+ return nil
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1814,10 +1788,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
var v linux.Linger
binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v)
- if v != (linux.Linger{}) {
- socket.SetSockOptEmitUnimplementedEvent(t, name)
- }
-
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: v.OnOff != 0,
Timeout: time.Second * time.Duration(v.Linger),
@@ -2596,7 +2566,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
defer s.readMu.Unlock()
res, err := s.Endpoint.Read(w, readOptions)
- if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 {
+ if _, ok := err.(*tcpip.ErrBadBuffer); ok && dst.NumBytes() == 0 {
err = nil
}
if err != nil {
@@ -2840,45 +2810,48 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- v := &ioSequencePayload{t, src}
- n, err := s.Endpoint.Write(v, opts)
- dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= v.src.NumBytes() || dontWait) {
- // Complete write.
- return int(n), nil
- }
- if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
- return int(n), syserr.TranslateNetstackError(err)
- }
-
- // We'll have to block. Register for notification and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut)
- defer s.EventUnregister(&e)
-
- v.DropFirst(int(n))
- total := n
+ r := src.Reader(t)
+ var (
+ total int64
+ entry waiter.Entry
+ ch <-chan struct{}
+ )
for {
- n, err = s.Endpoint.Write(v, opts)
- v.DropFirst(int(n))
+ n, err := s.Endpoint.Write(r, opts)
total += n
-
- if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
- return 0, syserr.TranslateNetstackError(err)
- }
-
- if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
- return int(total), nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return int(total), syserr.ErrTryAgain
+ if flags&linux.MSG_DONTWAIT != 0 {
+ return int(total), syserr.TranslateNetstackError(err)
+ }
+ block := true
+ switch err.(type) {
+ case nil:
+ block = total != src.NumBytes()
+ case *tcpip.ErrWouldBlock:
+ default:
+ block = false
+ }
+ if block {
+ if ch == nil {
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ entry, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&entry, waiter.EventOut)
+ defer s.EventUnregister(&entry)
+ } else {
+ // Don't wait immediately after registration in case more data
+ // became available between when we last checked and when we setup
+ // the notification.
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
+ // handleIOError will consume errors from t.Block if needed.
+ return int(total), syserr.FromError(err)
+ }
}
- // handleIOError will consume errors from t.Block if needed.
- return int(total), syserr.FromError(err)
+ continue
}
+ return int(total), syserr.TranslateNetstackError(err)
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 6f70b02fc..24922c400 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -129,20 +128,20 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
return 0, syserror.EOPNOTSUPP
}
- f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
- if err == tcpip.ErrWouldBlock {
+ r := src.Reader(ctx)
+ n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
return 0, syserror.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
- if int64(n) < src.NumBytes() {
- return int64(n), syserror.ErrWouldBlock
+ if n < src.NumBytes() {
+ return n, syserror.ErrWouldBlock
}
- return int64(n), nil
+ return n, nil
}
// Accept implements the linux syscall accept(2) for sockets backed by
@@ -155,7 +154,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
}
ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
- if terr != tcpip.ErrWouldBlock || !blocking {
+ if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
@@ -262,13 +261,3 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index c847ff1c7..2515dda80 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -118,7 +118,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
// Create the endpoint.
var ep tcpip.Endpoint
- var e *tcpip.Error
+ var e tcpip.Error
wq := &waiter.Queue{}
if stype == linux.SOCK_RAW {
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 0af805246..ba1cc79e9 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -62,7 +62,7 @@ func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int
// Create the endpoint.
var ep tcpip.Endpoint
- var e *tcpip.Error
+ var e tcpip.Error
wq := &waiter.Queue{}
if stype == linux.SOCK_RAW {
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 9f7aca305..fc5b823b0 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -48,7 +48,7 @@ type ConnectingEndpoint interface {
Type() linux.SockType
// GetLocalAddress returns the bound path.
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
// Locker protects the following methods. While locked, only the holder of
// the lock can change the return value of the protected methods.
@@ -128,7 +128,7 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, nil, nil)
return ep
}
@@ -173,7 +173,7 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, nil, nil)
return ep
}
@@ -296,7 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
- ne.ops.InitHandler(ne)
+ ne.ops.InitHandler(ne, nil, nil)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
readQueue.InitRefs()
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 0813ad87d..20fa8b874 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,7 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, nil, nil)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 099a56281..70227bbd2 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -169,32 +169,32 @@ type Endpoint interface {
Type() linux.SockType
// GetLocalAddress returns the address to which the endpoint is bound.
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
- GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetRemoteAddress() (tcpip.FullAddress, tcpip.Error)
// SetSockOpt sets a socket option.
- SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
+ SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error
// GetSockOpt gets a socket option.
- GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
+ GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
State() uint32
// LastError clears and returns the last error reported by the endpoint.
- LastError() *tcpip.Error
+ LastError() tcpip.Error
// SocketOptions returns the structure which contains all the socket
// level options.
@@ -580,7 +580,7 @@ type ConnectedEndpoint interface {
Passcred() bool
// GetLocalAddress implements Endpoint.GetLocalAddress.
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
// Send sends a single message. This method does not block.
//
@@ -640,7 +640,7 @@ type connectedEndpoint struct {
Passcred() bool
// GetLocalAddress implements Endpoint.GetLocalAddress.
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+ GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
// Type implements Endpoint.Type.
Type() linux.SockType
@@ -655,7 +655,7 @@ func (e *connectedEndpoint) Passcred() bool {
}
// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
-func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
return e.endpoint.GetLocalAddress()
}
@@ -836,13 +836,12 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
}
// SetSockOpt sets a socket option.
-func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
- case tcpip.SendBufferSizeOption:
case tcpip.ReceiveBufferSizeOption:
default:
log.Warningf("Unsupported socket option: %d", opt)
@@ -850,19 +849,40 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket.
+func (e *baseEndpoint) IsUnixSocket() bool {
+ return true
+}
+
+// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize.
+func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+
+ if !e.Connected() {
+ return -1, &tcpip.ErrNotConnected{}
+ }
+
+ v := e.connected.SendMaxQueueSize()
+ if v < 0 {
+ return -1, &tcpip.ErrQueueSizeNotSupported{}
+ }
+ return v, nil
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
e.Lock()
if !e.Connected() {
e.Unlock()
- return -1, tcpip.ErrNotConnected
+ return -1, &tcpip.ErrNotConnected{}
}
v = int(e.receiver.RecvQueuedSize())
e.Unlock()
if v < 0 {
- return -1, tcpip.ErrQueueSizeNotSupported
+ return -1, &tcpip.ErrQueueSizeNotSupported{}
}
return v, nil
@@ -870,25 +890,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.Lock()
if !e.Connected() {
e.Unlock()
- return -1, tcpip.ErrNotConnected
+ return -1, &tcpip.ErrNotConnected{}
}
v := e.connected.SendQueuedSize()
e.Unlock()
if v < 0 {
- return -1, tcpip.ErrQueueSizeNotSupported
- }
- return int(v), nil
-
- case tcpip.SendBufferSizeOption:
- e.Lock()
- if !e.Connected() {
- e.Unlock()
- return -1, tcpip.ErrNotConnected
- }
- v := e.connected.SendMaxQueueSize()
- e.Unlock()
- if v < 0 {
- return -1, tcpip.ErrQueueSizeNotSupported
+ return -1, &tcpip.ErrQueueSizeNotSupported{}
}
return int(v), nil
@@ -896,29 +903,29 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.Lock()
if e.receiver == nil {
e.Unlock()
- return -1, tcpip.ErrNotConnected
+ return -1, &tcpip.ErrNotConnected{}
}
v := e.receiver.RecvMaxQueueSize()
e.Unlock()
if v < 0 {
- return -1, tcpip.ErrQueueSizeNotSupported
+ return -1, &tcpip.ErrQueueSizeNotSupported{}
}
return int(v), nil
default:
log.Warningf("Unsupported socket option: %d", opt)
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
// LastError implements Endpoint.LastError.
-func (*baseEndpoint) LastError() *tcpip.Error {
+func (*baseEndpoint) LastError() tcpip.Error {
return nil
}
@@ -958,7 +965,7 @@ func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
}
// GetLocalAddress returns the bound path.
-func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.Lock()
defer e.Unlock()
return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
@@ -966,14 +973,14 @@ func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress returns the local address of the connected endpoint (if
// available).
-func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.Lock()
c := e.connected
e.Unlock()
if c != nil {
return c.GetLocalAddress()
}
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Release implements BoundEndpoint.Release.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 6c4ec55b2..32e5d2304 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -496,6 +496,9 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
return int(n), syserr.FromError(err)
}
+ // Only send SCM Rights once (see net/unix/af_unix.c:unix_stream_sendmsg).
+ w.Control.Rights = nil
+
// We'll have to block. Register for notification and keep trying to
// send all the data.
e, ch := waiter.NewChannelEntry(nil)
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 27f705bb2..a7d4d7f1f 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -331,16 +330,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
-}
-
// providerVFS2 is a unix domain socket provider for VFS2.
type providerVFS2 struct{}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index a72df62f6..62d1e8f8b 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -503,8 +503,8 @@ var ARM64 = &kernel.SyscallTable{
72: syscalls.Supported("pselect", Pselect),
73: syscalls.Supported("ppoll", Ppoll),
74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
- 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 76: syscalls.Supported("splice", Splice),
77: syscalls.Supported("tee", Tee),
78: syscalls.Supported("readlinkat", Readlinkat),
79: syscalls.Supported("fstatat", Fstatat),
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index c33571f43..a6253626e 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1014,12 +1014,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
@@ -1030,12 +1030,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
@@ -2167,24 +2167,24 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.LOCK_EX:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_SH:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 7dd9ef857..1a31898e8 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -123,6 +123,15 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
defer file.DecRef(t)
+ if file.StatusFlags()&linux.O_PATH != 0 {
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC, linux.F_GETFD, linux.F_SETFD, linux.F_GETFL:
+ // allowed
+ default:
+ return 0, nil, syserror.EBADF
+ }
+ }
+
switch cmd {
case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
minfd := args[2].Int()
@@ -205,8 +214,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
err := tmpfs.AddSeals(file, args[2].Uint())
return 0, nil, err
- case linux.F_SETLK, linux.F_SETLKW:
- return 0, nil, posixLock(t, args, file, cmd)
+ case linux.F_SETLK:
+ return 0, nil, posixLock(t, args, file, false /* blocking */)
+ case linux.F_SETLKW:
+ return 0, nil, posixLock(t, args, file, true /* blocking */)
+ case linux.F_GETLK:
+ return 0, nil, posixTestLock(t, args, file)
case linux.F_GETSIG:
a := file.AsyncHandler()
if a == nil {
@@ -292,7 +305,49 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType,
}
}
-func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error {
+func posixTestLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription) error {
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock linux.Flock
+ if _, err := flock.CopyIn(t, flockAddr); err != nil {
+ return err
+ }
+ var typ lock.LockType
+ switch flock.Type {
+ case linux.F_RDLCK:
+ typ = lock.ReadLock
+ case linux.F_WRLCK:
+ typ = lock.WriteLock
+ default:
+ return syserror.EINVAL
+ }
+ r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence)
+ if err != nil {
+ return err
+ }
+
+ newFlock, err := file.TestPOSIX(t, t.FDTable(), typ, r)
+ if err != nil {
+ return err
+ }
+ newFlock.PID = translatePID(t.PIDNamespace().Root(), t.PIDNamespace(), newFlock.PID)
+ if _, err = newFlock.CopyOut(t, flockAddr); err != nil {
+ return err
+ }
+ return nil
+}
+
+// translatePID translates a pid from one namespace to another. Note that this
+// may race with task termination/creation, in which case the original task
+// corresponding to pid may no longer exist. This is used to implement the
+// F_GETLK fcntl, which has the same potential race in Linux as well (i.e.,
+// there is no synchronization between retrieving the lock PID and translating
+// it). See fs/locks.c:posix_lock_to_flock.
+func translatePID(old, new *kernel.PIDNamespace, pid int32) int32 {
+ return int32(new.IDOfTask(old.TaskWithID(kernel.ThreadID(pid))))
+}
+
+func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, blocking bool) error {
// Copy in the lock request.
flockAddr := args[2].Pointer()
var flock linux.Flock
@@ -301,25 +356,30 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip
}
var blocker lock.Blocker
- if cmd == linux.F_SETLKW {
+ if blocking {
blocker = t
}
+ r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence)
+ if err != nil {
+ return err
+ }
+
switch flock.Type {
case linux.F_RDLCK:
if !file.IsReadable() {
return syserror.EBADF
}
- return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+ return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.ReadLock, r, blocker)
case linux.F_WRLCK:
if !file.IsWritable() {
return syserror.EBADF
}
- return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+ return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.WriteLock, r, blocker)
case linux.F_UNLCK:
- return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence)
+ return file.UnlockPOSIX(t, t.FDTable(), r)
default:
return syserror.EINVAL
@@ -344,6 +404,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef(t)
+ if file.StatusFlags()&linux.O_PATH != 0 {
+ return 0, nil, syserror.EBADF
+ }
+
// If the FD refers to a pipe or FIFO, return error.
if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe {
return 0, nil, syserror.ESPIPE
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 20c264fef..c7c3fed57 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -32,6 +32,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
defer file.DecRef(t)
+ if file.StatusFlags()&linux.O_PATH != 0 {
+ return 0, nil, syserror.EBADF
+ }
+
// Handle ioctls that apply to all FDs.
switch args[1].Int() {
case linux.FIONCLEX:
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
index b910b5a74..d1452a04d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/lock.go
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -44,11 +44,11 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch operation {
case linux.LOCK_EX:
- if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil {
+ if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.WriteLock, blocker); err != nil {
return 0, nil, err
}
case linux.LOCK_SH:
- if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil {
+ if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.ReadLock, blocker); err != nil {
return 0, nil, err
}
case linux.LOCK_UN:
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index b77b29dcc..c7417840f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -93,17 +93,11 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
n, err := file.Read(t, dst, opts)
if err != syserror.ErrWouldBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return n, err
}
@@ -134,9 +128,6 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
}
file.EventUnregister(&w)
- if total > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return total, err
}
@@ -257,17 +248,11 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
n, err := file.PRead(t, dst, offset, opts)
if err != syserror.ErrWouldBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return n, err
}
@@ -297,10 +282,6 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
}
}
file.EventUnregister(&w)
-
- if total > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return total, err
}
@@ -363,17 +344,11 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
n, err := file.Write(t, src, opts)
if err != syserror.ErrWouldBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- }
return n, err
}
@@ -403,10 +378,6 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
}
}
file.EventUnregister(&w)
-
- if total > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- }
return total, err
}
@@ -527,17 +498,11 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
n, err := file.PWrite(t, src, offset, opts)
if err != syserror.ErrWouldBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
- if n > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return n, err
}
@@ -567,10 +532,6 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
}
}
file.EventUnregister(&w)
-
- if total > 0 {
- file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- }
return total, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 1ee37e5a8..903169dc2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -220,7 +220,6 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
length := args[3].Int64()
file := t.GetFileVFS2(fd)
-
if file == nil {
return 0, nil, syserror.EBADF
}
@@ -229,23 +228,18 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if !file.IsWritable() {
return 0, nil, syserror.EBADF
}
-
if mode != 0 {
return 0, nil, syserror.ENOTSUP
}
-
if offset < 0 || length <= 0 {
return 0, nil, syserror.EINVAL
}
size := offset + length
-
if size < 0 {
return 0, nil, syserror.EFBIG
}
-
limit := limits.FromContext(t).Get(limits.FileSize).Cur
-
if uint64(size) >= limit {
t.SendSignal(&arch.SignalInfo{
Signo: int32(linux.SIGXFSZ),
@@ -254,12 +248,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EFBIG
}
- if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil {
- return 0, nil, err
- }
-
- file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- return 0, nil, nil
+ return 0, nil, file.Allocate(t, mode, uint64(offset), uint64(length))
}
// Utime implements Linux syscall utime(2).
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 8bb763a47..19e175203 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -170,13 +170,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
}
- if n != 0 {
- // On Linux, inotify behavior is not very consistent with splice(2). We try
- // our best to emulate Linux for very basic calls to splice, where for some
- // reason, events are generated for output files, but not input files.
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- }
-
// We can only pass a single file to handleIOError, so pick inFile arbitrarily.
// This is used only for debugging purposes.
return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile)
@@ -256,8 +249,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
}
if n != 0 {
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
-
// If a partial write is completed, the error is dropped. Log it here.
if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
log.Debugf("tee completed a partial write with error: %v", err)
@@ -449,9 +440,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
if total != 0 {
- inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
-
if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
// If a partial write is completed, the error is dropped. Log it here.
log.Debugf("sendfile completed a partial write with error: %v", err)
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
index 6e9b599e2..1f8a5878c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/sync.go
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -36,6 +36,10 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
defer file.DecRef(t)
+ if file.StatusFlags()&linux.O_PATH != 0 {
+ return 0, nil, syserror.EBADF
+ }
+
return 0, nil, file.SyncFS(t)
}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index a3868bf16..df4990854 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -83,6 +83,7 @@ go_library(
"mount.go",
"mount_namespace_refs.go",
"mount_unsafe.go",
+ "opath.go",
"options.go",
"pathname.go",
"permissions.go",
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 5321ac80a..f612a71b2 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -161,6 +161,13 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou
// DecRef decrements fd's reference count.
func (fd *FileDescription) DecRef(ctx context.Context) {
fd.FileDescriptionRefs.DecRef(func() {
+ // Generate inotify events.
+ ev := uint32(linux.IN_CLOSE_NOWRITE)
+ if fd.IsWritable() {
+ ev = linux.IN_CLOSE_WRITE
+ }
+ fd.Dentry().InotifyWithParent(ctx, ev, 0, PathEvent)
+
// Unregister fd from all epoll instances.
fd.epollMu.Lock()
epolls := fd.epolls
@@ -448,16 +455,19 @@ type FileDescriptionImpl interface {
RemoveXattr(ctx context.Context, name string) error
// LockBSD tries to acquire a BSD-style advisory file lock.
- LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
+ LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error
// UnlockBSD releases a BSD-style advisory file lock.
UnlockBSD(ctx context.Context, uid lock.UniqueID) error
// LockPOSIX tries to acquire a POSIX-style advisory file lock.
- LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error
+ LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error
// UnlockPOSIX releases a POSIX-style advisory file lock.
- UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error
+ UnlockPOSIX(ctx context.Context, uid lock.UniqueID, ComputeLockRange lock.LockRange) error
+
+ // TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl.
+ TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error)
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -556,7 +566,11 @@ func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length ui
if !fd.IsWritable() {
return syserror.EBADF
}
- return fd.impl.Allocate(ctx, mode, offset, length)
+ if err := fd.impl.Allocate(ctx, mode, offset, length); err != nil {
+ return err
+ }
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent)
+ return nil
}
// Readiness implements waiter.Waitable.Readiness.
@@ -592,6 +606,9 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
}
start := fsmetric.StartReadWait()
n, err := fd.impl.PRead(ctx, dst, offset, opts)
+ if n > 0 {
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent)
+ }
fsmetric.Reads.Increment()
fsmetric.FinishReadWait(fsmetric.ReadWait, start)
return n, err
@@ -604,6 +621,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt
}
start := fsmetric.StartReadWait()
n, err := fd.impl.Read(ctx, dst, opts)
+ if n > 0 {
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent)
+ }
fsmetric.Reads.Increment()
fsmetric.FinishReadWait(fsmetric.ReadWait, start)
return n, err
@@ -619,7 +639,11 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o
if !fd.writable {
return 0, syserror.EBADF
}
- return fd.impl.PWrite(ctx, src, offset, opts)
+ n, err := fd.impl.PWrite(ctx, src, offset, opts)
+ if n > 0 {
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent)
+ }
+ return n, err
}
// Write is similar to PWrite, but does not specify an offset.
@@ -627,7 +651,11 @@ func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, op
if !fd.writable {
return 0, syserror.EBADF
}
- return fd.impl.Write(ctx, src, opts)
+ n, err := fd.impl.Write(ctx, src, opts)
+ if n > 0 {
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent)
+ }
+ return n, err
}
// IterDirents invokes cb on each entry in the directory represented by fd. If
@@ -791,9 +819,9 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e
}
// LockBSD tries to acquire a BSD-style advisory file lock.
-func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error {
+func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error {
atomic.StoreUint32(&fd.usedLockBSD, 1)
- return fd.impl.LockBSD(ctx, fd, lockType, blocker)
+ return fd.impl.LockBSD(ctx, fd, ownerPID, lockType, blocker)
}
// UnlockBSD releases a BSD-style advisory file lock.
@@ -802,13 +830,45 @@ func (fd *FileDescription) UnlockBSD(ctx context.Context) error {
}
// LockPOSIX locks a POSIX-style file range lock.
-func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error {
- return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block)
+func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error {
+ return fd.impl.LockPOSIX(ctx, uid, ownerPID, t, r, block)
}
// UnlockPOSIX unlocks a POSIX-style file range lock.
-func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error {
- return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence)
+func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, r lock.LockRange) error {
+ return fd.impl.UnlockPOSIX(ctx, uid, r)
+}
+
+// TestPOSIX returns information about whether the specified lock can be held.
+func (fd *FileDescription) TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error) {
+ return fd.impl.TestPOSIX(ctx, uid, t, r)
+}
+
+// ComputeLockRange computes the range of a file lock based on the given values.
+func (fd *FileDescription) ComputeLockRange(ctx context.Context, start uint64, length uint64, whence int16) (lock.LockRange, error) {
+ var off int64
+ switch whence {
+ case linux.SEEK_SET:
+ off = 0
+ case linux.SEEK_CUR:
+ // Note that Linux does not hold any mutexes while retrieving the file
+ // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR)
+ if err != nil {
+ return lock.LockRange{}, err
+ }
+ off = curOff
+ case linux.SEEK_END:
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE})
+ if err != nil {
+ return lock.LockRange{}, err
+ }
+ off = int64(stat.Size)
+ default:
+ return lock.LockRange{}, syserror.EINVAL
+ }
+
+ return lock.ComputeRange(int64(start), int64(length), off)
}
// A FileAsync sends signals to its owner when w is ready for IO. This is only
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 48ca9de44..eb7d2fd3b 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -419,8 +419,8 @@ func (fd *LockFD) Locks() *FileLocks {
}
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
-func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
- return fd.locks.LockBSD(uid, t, block)
+func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
+ return fd.locks.LockBSD(ctx, uid, ownerPID, t, block)
}
// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
@@ -429,6 +429,21 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
return nil
}
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
+ return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
+ return fd.locks.UnlockPOSIX(ctx, uid, r)
+}
+
+// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX.
+func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
+ return fd.locks.TestPOSIX(ctx, uid, t, r)
+}
+
// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
// returning ENOLCK.
//
@@ -436,7 +451,7 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
type NoLockFD struct{}
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
-func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return syserror.ENOLCK
}
@@ -446,11 +461,16 @@ func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
}
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
return syserror.ENOLCK
}
// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
return syserror.ENOLCK
}
+
+// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX.
+func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
+ return linux.Flock{}, syserror.ENOLCK
+}
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index 1ff202f2a..cbe4d8c2d 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -39,8 +39,8 @@ type FileLocks struct {
}
// LockBSD tries to acquire a BSD-style lock on the entire file.
-func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
- if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) {
+func (fl *FileLocks) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerID int32, t fslock.LockType, block fslock.Blocker) error {
+ if fl.bsd.LockRegion(uid, ownerID, t, fslock.LockRange{0, fslock.LockEOF}, block) {
return nil
}
@@ -61,12 +61,8 @@ func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) {
}
// LockPOSIX tries to acquire a POSIX-style lock on a file region.
-func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- rng, err := computeRange(ctx, fd, start, length, whence)
- if err != nil {
- return err
- }
- if fl.posix.LockRegion(uid, t, rng, block) {
+func (fl *FileLocks) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
+ if fl.posix.LockRegion(uid, ownerPID, t, r, block) {
return nil
}
@@ -82,37 +78,12 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fsl
//
// This operation is always successful, even if there did not exist a lock on
// the requested region held by uid in the first place.
-func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error {
- rng, err := computeRange(ctx, fd, start, length, whence)
- if err != nil {
- return err
- }
- fl.posix.UnlockRegion(uid, rng)
+func (fl *FileLocks) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
+ fl.posix.UnlockRegion(uid, r)
return nil
}
-func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) {
- var off int64
- switch whence {
- case linux.SEEK_SET:
- off = 0
- case linux.SEEK_CUR:
- // Note that Linux does not hold any mutexes while retrieving the file
- // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
- curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR)
- if err != nil {
- return fslock.LockRange{}, err
- }
- off = curOff
- case linux.SEEK_END:
- stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE})
- if err != nil {
- return fslock.LockRange{}, err
- }
- off = int64(stat.Size)
- default:
- return fslock.LockRange{}, syserror.EINVAL
- }
-
- return fslock.ComputeRange(int64(start), int64(length), off)
+// TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl.
+func (fl *FileLocks) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
+ return fl.posix.TestRegion(ctx, uid, t, r), nil
}
diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go
new file mode 100644
index 000000000..39fbac987
--- /dev/null
+++ b/pkg/sentry/vfs/opath.go
@@ -0,0 +1,139 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH.
+//
+// +stateify savable
+type opathFD struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *opathFD) Release(context.Context) {
+ // noop
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.EBADF
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.EBADF
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return syserror.EBADF
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return syserror.EBADF
+}
+
+// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
+func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ return nil, syserror.EBADF
+}
+
+// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
+func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) {
+ return "", syserror.EBADF
+}
+
+// SetXattr implements vfs.FileDescriptionImpl.SetXattr.
+func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error {
+ return syserror.EBADF
+}
+
+// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr.
+func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error {
+ return syserror.EBADF
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *opathFD) Sync(ctx context.Context) error {
+ return syserror.EBADF
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error {
+ return syserror.EBADF
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ vfsObj := fd.vfsfd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vfsfd.vd,
+ Start: fd.vfsfd.vd,
+ })
+ stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(ctx, rp)
+ return stat, err
+}
+
+// StatFS returns metadata for the filesystem containing the file represented
+// by fd.
+func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) {
+ vfsObj := fd.vfsfd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vfsfd.vd,
+ Start: fd.vfsfd.vd,
+ })
+ statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp)
+ vfsObj.putResolvingPath(ctx, rp)
+ return statfs, err
+}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index bc79e5ecc..c9907843c 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -129,7 +129,7 @@ type OpenOptions struct {
//
// FilesystemImpls are responsible for implementing the following flags:
// O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC,
- // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and
+ // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_SYNC, O_TMPFILE, and
// O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and
// O_NOFOLLOW. VFS users are responsible for handling O_CLOEXEC, since file
// descriptors are mostly outside the scope of VFS.
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 6fd1bb0b2..0aff2dd92 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -425,6 +425,18 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
rp.mustBeDir = true
rp.mustBeDirOrig = true
}
+ if opts.Flags&linux.O_PATH != 0 {
+ vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{})
+ if err != nil {
+ return nil, err
+ }
+ fd := &opathFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ vd.DecRef(ctx)
+ return &fd.vfsfd, err
+ }
for {
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 28e62abbb..2e2395807 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -67,6 +67,7 @@ go_library(
],
marshal = False,
stateify = False,
+ visibility = ["//:sandbox"],
deps = [
"//pkg/goid",
],
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index cb8981633..a6a91e064 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -23,95 +23,114 @@ import (
// Mapping for tcpip.Error types.
var (
- ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
- ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV)
- ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
- ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
- ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
- ErrDuplicateAddress = New(tcpip.ErrDuplicateAddress.String(), linux.EEXIST)
- ErrBadLinkEndpoint = New(tcpip.ErrBadLinkEndpoint.String(), linux.EINVAL)
- ErrAlreadyBound = New(tcpip.ErrAlreadyBound.String(), linux.EINVAL)
- ErrInvalidEndpointState = New(tcpip.ErrInvalidEndpointState.String(), linux.EINVAL)
- ErrAlreadyConnecting = New(tcpip.ErrAlreadyConnecting.String(), linux.EALREADY)
- ErrNoPortAvailable = New(tcpip.ErrNoPortAvailable.String(), linux.EAGAIN)
- ErrPortInUse = New(tcpip.ErrPortInUse.String(), linux.EADDRINUSE)
- ErrBadLocalAddress = New(tcpip.ErrBadLocalAddress.String(), linux.EADDRNOTAVAIL)
- ErrClosedForSend = New(tcpip.ErrClosedForSend.String(), linux.EPIPE)
- ErrClosedForReceive = New(tcpip.ErrClosedForReceive.String(), nil)
- ErrTimeout = New(tcpip.ErrTimeout.String(), linux.ETIMEDOUT)
- ErrAborted = New(tcpip.ErrAborted.String(), linux.EPIPE)
- ErrConnectStarted = New(tcpip.ErrConnectStarted.String(), linux.EINPROGRESS)
- ErrDestinationRequired = New(tcpip.ErrDestinationRequired.String(), linux.EDESTADDRREQ)
- ErrNotSupported = New(tcpip.ErrNotSupported.String(), linux.EOPNOTSUPP)
- ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY)
- ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT)
- ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
- ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
- ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
- ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT)
+ ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL)
+ ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV)
+ ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV)
+ ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT)
+ ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST)
+ ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST)
+ ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL)
+ ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL)
+ ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY)
+ ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN)
+ ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE)
+ ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL)
+ ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE)
+ ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil)
+ ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT)
+ ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE)
+ ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS)
+ ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ)
+ ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP)
+ ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY)
+ ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT)
+ ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL)
+ ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES)
+ ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM)
+ ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT)
)
-var netstackErrorTranslations map[string]*Error
-
-func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) {
- key := tcpipErr.String()
- if _, ok := netstackErrorTranslations[key]; ok {
- panic(fmt.Sprintf("duplicate error key: %s", key))
- }
- netstackErrorTranslations[key] = netstackErr
-}
-
-func init() {
- netstackErrorTranslations = make(map[string]*Error)
- addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol)
- addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID)
- addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice)
- addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption)
- addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID)
- addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress)
- addErrMapping(tcpip.ErrNoRoute, ErrNoRoute)
- addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint)
- addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound)
- addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState)
- addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting)
- addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected)
- addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable)
- addErrMapping(tcpip.ErrPortInUse, ErrPortInUse)
- addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress)
- addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend)
- addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive)
- addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock)
- addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused)
- addErrMapping(tcpip.ErrTimeout, ErrTimeout)
- addErrMapping(tcpip.ErrAborted, ErrAborted)
- addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted)
- addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired)
- addErrMapping(tcpip.ErrNotSupported, ErrNotSupported)
- addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported)
- addErrMapping(tcpip.ErrNotConnected, ErrNotConnected)
- addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset)
- addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted)
- addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile)
- addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue)
- addErrMapping(tcpip.ErrBadAddress, ErrBadAddress)
- addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable)
- addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong)
- addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace)
- addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
- addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
- addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
- addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer)
-}
-
// TranslateNetstackError converts an error from the tcpip package to a sentry
// internal error.
-func TranslateNetstackError(err *tcpip.Error) *Error {
- if err == nil {
+func TranslateNetstackError(err tcpip.Error) *Error {
+ switch err.(type) {
+ case nil:
return nil
+ case *tcpip.ErrUnknownProtocol:
+ return ErrUnknownProtocol
+ case *tcpip.ErrUnknownNICID:
+ return ErrUnknownNICID
+ case *tcpip.ErrUnknownDevice:
+ return ErrUnknownDevice
+ case *tcpip.ErrUnknownProtocolOption:
+ return ErrUnknownProtocolOption
+ case *tcpip.ErrDuplicateNICID:
+ return ErrDuplicateNICID
+ case *tcpip.ErrDuplicateAddress:
+ return ErrDuplicateAddress
+ case *tcpip.ErrNoRoute:
+ return ErrNoRoute
+ case *tcpip.ErrAlreadyBound:
+ return ErrAlreadyBound
+ case *tcpip.ErrInvalidEndpointState:
+ return ErrInvalidEndpointState
+ case *tcpip.ErrAlreadyConnecting:
+ return ErrAlreadyConnecting
+ case *tcpip.ErrAlreadyConnected:
+ return ErrAlreadyConnected
+ case *tcpip.ErrNoPortAvailable:
+ return ErrNoPortAvailable
+ case *tcpip.ErrPortInUse:
+ return ErrPortInUse
+ case *tcpip.ErrBadLocalAddress:
+ return ErrBadLocalAddress
+ case *tcpip.ErrClosedForSend:
+ return ErrClosedForSend
+ case *tcpip.ErrClosedForReceive:
+ return ErrClosedForReceive
+ case *tcpip.ErrWouldBlock:
+ return ErrWouldBlock
+ case *tcpip.ErrConnectionRefused:
+ return ErrConnectionRefused
+ case *tcpip.ErrTimeout:
+ return ErrTimeout
+ case *tcpip.ErrAborted:
+ return ErrAborted
+ case *tcpip.ErrConnectStarted:
+ return ErrConnectStarted
+ case *tcpip.ErrDestinationRequired:
+ return ErrDestinationRequired
+ case *tcpip.ErrNotSupported:
+ return ErrNotSupported
+ case *tcpip.ErrQueueSizeNotSupported:
+ return ErrQueueSizeNotSupported
+ case *tcpip.ErrNotConnected:
+ return ErrNotConnected
+ case *tcpip.ErrConnectionReset:
+ return ErrConnectionReset
+ case *tcpip.ErrConnectionAborted:
+ return ErrConnectionAborted
+ case *tcpip.ErrNoSuchFile:
+ return ErrNoSuchFile
+ case *tcpip.ErrInvalidOptionValue:
+ return ErrInvalidOptionValue
+ case *tcpip.ErrBadAddress:
+ return ErrBadAddress
+ case *tcpip.ErrNetworkUnreachable:
+ return ErrNetworkUnreachable
+ case *tcpip.ErrMessageTooLong:
+ return ErrMessageTooLong
+ case *tcpip.ErrNoBufferSpace:
+ return ErrNoBufferSpace
+ case *tcpip.ErrBroadcastDisabled:
+ return ErrBroadcastDisabled
+ case *tcpip.ErrNotPermitted:
+ return ErrNotPermittedNet
+ case *tcpip.ErrAddressFamilyNotSupported:
+ return ErrAddressFamilyNotSupported
+ case *tcpip.ErrBadBuffer:
+ return ErrBadBuffer
+ default:
+ panic(fmt.Sprintf("unknown error %T", err))
}
- se, ok := netstackErrorTranslations[err.String()]
- if !ok {
- panic("Unknown error: " + err.String())
- }
- return se
}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index e7924e5c2..f979d22f0 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -18,6 +18,7 @@ go_template_instance(
go_library(
name = "tcpip",
srcs = [
+ "errors.go",
"sock_err_list.go",
"socketops.go",
"tcpip.go",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index fdeec12d3..c188aaa18 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -16,6 +16,7 @@
package gonet
import (
+ "bytes"
"context"
"errors"
"io"
@@ -247,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
l.wq.EventRegister(&waitEntry, waiter.EventIn)
@@ -256,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
for {
n, wq, err = l.ep.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
@@ -297,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
res, err := ep.Read(&w, opts)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
res, err = ep.Read(&w, opts)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
select {
@@ -315,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
}
}
- if err == tcpip.ErrClosedForReceive {
+ if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
return 0, io.EOF
}
@@ -354,10 +355,8 @@ func (c *TCPConn) Write(b []byte) (int, error) {
default:
}
- v := buffer.NewViewFromBytes(b)
-
// We must handle two soft failure conditions simultaneously:
- // 1. Write may write nothing and return tcpip.ErrWouldBlock.
+ // 1. Write may write nothing and return *tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
// not already and wait to try again.
// 2. Write may write fewer than the full number of bytes and return
@@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) {
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
- err *tcpip.Error
- nbytes int
- reg bool
- notifyCh chan struct{}
+ r bytes.Reader
+ nbytes int
+ entry waiter.Entry
+ ch <-chan struct{}
)
- for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
- if err == tcpip.ErrWouldBlock {
- if !reg {
- // Only register once.
- reg = true
-
- // Create wait queue entry that notifies a channel.
- var waitEntry waiter.Entry
- waitEntry, notifyCh = waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
+ for nbytes != len(b) {
+ r.Reset(b[nbytes:])
+ n, err := c.ep.Write(&r, tcpip.WriteOptions{})
+ nbytes += int(n)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrWouldBlock:
+ if ch == nil {
+ entry, ch = waiter.NewChannelEntry(nil)
+
+ c.wq.EventRegister(&entry, waiter.EventOut)
+ defer c.wq.EventUnregister(&entry)
} else {
// Don't wait immediately after registration in case more data
// became available between when we last checked and when we setup
@@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) {
select {
case <-deadline:
return nbytes, c.newOpError("write", &timeoutError{})
- case <-notifyCh:
+ case <-ch:
+ continue
}
}
+ default:
+ return nbytes, c.newOpError("write", errors.New(err.String()))
}
-
- var n int64
- n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
-
- if err == nil {
- return nbytes, nil
}
-
- return nbytes, c.newOpError("write", errors.New(err.String()))
+ return nbytes, nil
}
// Close implements net.Conn.Close.
@@ -502,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
ep.Close()
@@ -644,17 +637,19 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// If we're being called by Write, there is no addr
- wopts := tcpip.WriteOptions{}
+ writeOptions := tcpip.WriteOptions{}
if addr != nil {
ua := addr.(*net.UDPAddr)
- wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+ writeOptions.To = &tcpip.FullAddress{
+ Addr: tcpip.Address(ua.IP),
+ Port: uint16(ua.Port),
+ }
}
- v := buffer.NewView(len(b))
- copy(v, b)
-
- n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err == tcpip.ErrWouldBlock {
+ var r bytes.Reader
+ r.Reset(b)
+ n, err := c.ep.Write(&r, writeOptions)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&waitEntry, waiter.EventOut)
@@ -666,8 +661,8 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
case <-notifyCh:
}
- n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err != tcpip.ErrWouldBlock {
+ n, err = c.ep.Write(&r, writeOptions)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index b196324c7..2b3ea4bdf 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) {
}
}
-func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+func newLoopbackStack() (*stack.Stack, tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
@@ -94,7 +94,7 @@ type testConnection struct {
ep tcpip.Endpoint
}
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
@@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er
wq.EventRegister(&entry, waiter.EventOut)
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
<-ch
err = ep.LastError()
}
@@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrNoRoute
- if !ok || got.Err.Error() != want.String() {
- t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+ switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) {
+ case *net.OpError:
+ if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() {
+ t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{})
+ }
+ default:
+ t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{})
}
}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 0ac2000ca..07b4393a4 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -234,7 +234,7 @@ func IPv4RouterAlert() NetworkChecker {
for {
opt, done, err := iterator.Next()
if err != nil {
- t.Fatalf("error acquiring next IPv4 option %s", err)
+ t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
}
if done {
break
diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go
new file mode 100644
index 000000000..af46da1d2
--- /dev/null
+++ b/pkg/tcpip/errors.go
@@ -0,0 +1,538 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+)
+
+// Error represents an error in the netstack error space.
+//
+// The error interface is intentionally omitted to avoid loss of type
+// information that would occur if these errors were passed as error.
+type Error interface {
+ isError()
+
+ // IgnoreStats indicates whether this error should be included in failure
+ // counts in tcpip.Stats structs.
+ IgnoreStats() bool
+
+ fmt.Stringer
+}
+
+// ErrAborted indicates the operation was aborted.
+//
+// +stateify savable
+type ErrAborted struct{}
+
+func (*ErrAborted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAborted) IgnoreStats() bool {
+ return false
+}
+func (*ErrAborted) String() string {
+ return "operation aborted"
+}
+
+// ErrAddressFamilyNotSupported indicates the operation does not support the
+// given address family.
+//
+// +stateify savable
+type ErrAddressFamilyNotSupported struct{}
+
+func (*ErrAddressFamilyNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAddressFamilyNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrAddressFamilyNotSupported) String() string {
+ return "address family not supported by protocol"
+}
+
+// ErrAlreadyBound indicates the endpoint is already bound.
+//
+// +stateify savable
+type ErrAlreadyBound struct{}
+
+func (*ErrAlreadyBound) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyBound) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyBound) String() string { return "endpoint already bound" }
+
+// ErrAlreadyConnected indicates the endpoint is already connected.
+//
+// +stateify savable
+type ErrAlreadyConnected struct{}
+
+func (*ErrAlreadyConnected) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyConnected) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyConnected) String() string { return "endpoint is already connected" }
+
+// ErrAlreadyConnecting indicates the endpoint is already connecting.
+//
+// +stateify savable
+type ErrAlreadyConnecting struct{}
+
+func (*ErrAlreadyConnecting) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyConnecting) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyConnecting) String() string { return "endpoint is already connecting" }
+
+// ErrBadAddress indicates a bad address was provided.
+//
+// +stateify savable
+type ErrBadAddress struct{}
+
+func (*ErrBadAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadAddress) String() string { return "bad address" }
+
+// ErrBadBuffer indicates a bad buffer was provided.
+//
+// +stateify savable
+type ErrBadBuffer struct{}
+
+func (*ErrBadBuffer) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadBuffer) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadBuffer) String() string { return "bad buffer" }
+
+// ErrBadLocalAddress indicates a bad local address was provided.
+//
+// +stateify savable
+type ErrBadLocalAddress struct{}
+
+func (*ErrBadLocalAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadLocalAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadLocalAddress) String() string { return "bad local address" }
+
+// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint.
+//
+// +stateify savable
+type ErrBroadcastDisabled struct{}
+
+func (*ErrBroadcastDisabled) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBroadcastDisabled) IgnoreStats() bool {
+ return false
+}
+func (*ErrBroadcastDisabled) String() string { return "broadcast socket option disabled" }
+
+// ErrClosedForReceive indicates the endpoint is closed for incoming data.
+//
+// +stateify savable
+type ErrClosedForReceive struct{}
+
+func (*ErrClosedForReceive) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrClosedForReceive) IgnoreStats() bool {
+ return false
+}
+func (*ErrClosedForReceive) String() string { return "endpoint is closed for receive" }
+
+// ErrClosedForSend indicates the endpoint is closed for outgoing data.
+//
+// +stateify savable
+type ErrClosedForSend struct{}
+
+func (*ErrClosedForSend) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrClosedForSend) IgnoreStats() bool {
+ return false
+}
+func (*ErrClosedForSend) String() string { return "endpoint is closed for send" }
+
+// ErrConnectStarted indicates the endpoint is connecting asynchronously.
+//
+// +stateify savable
+type ErrConnectStarted struct{}
+
+func (*ErrConnectStarted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectStarted) IgnoreStats() bool {
+ return true
+}
+func (*ErrConnectStarted) String() string { return "connection attempt started" }
+
+// ErrConnectionAborted indicates the connection was aborted.
+//
+// +stateify savable
+type ErrConnectionAborted struct{}
+
+func (*ErrConnectionAborted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionAborted) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionAborted) String() string { return "connection aborted" }
+
+// ErrConnectionRefused indicates the connection was refused.
+//
+// +stateify savable
+type ErrConnectionRefused struct{}
+
+func (*ErrConnectionRefused) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionRefused) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionRefused) String() string { return "connection was refused" }
+
+// ErrConnectionReset indicates the connection was reset.
+//
+// +stateify savable
+type ErrConnectionReset struct{}
+
+func (*ErrConnectionReset) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionReset) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionReset) String() string { return "connection reset by peer" }
+
+// ErrDestinationRequired indicates the operation requires a destination
+// address, and one was not provided.
+//
+// +stateify savable
+type ErrDestinationRequired struct{}
+
+func (*ErrDestinationRequired) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDestinationRequired) IgnoreStats() bool {
+ return false
+}
+func (*ErrDestinationRequired) String() string { return "destination address is required" }
+
+// ErrDuplicateAddress indicates the operation encountered a duplicate address.
+//
+// +stateify savable
+type ErrDuplicateAddress struct{}
+
+func (*ErrDuplicateAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDuplicateAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrDuplicateAddress) String() string { return "duplicate address" }
+
+// ErrDuplicateNICID indicates the operation encountered a duplicate NIC ID.
+//
+// +stateify savable
+type ErrDuplicateNICID struct{}
+
+func (*ErrDuplicateNICID) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDuplicateNICID) IgnoreStats() bool {
+ return false
+}
+func (*ErrDuplicateNICID) String() string { return "duplicate nic id" }
+
+// ErrInvalidEndpointState indicates the endpoint is in an invalid state.
+//
+// +stateify savable
+type ErrInvalidEndpointState struct{}
+
+func (*ErrInvalidEndpointState) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidEndpointState) IgnoreStats() bool {
+ return false
+}
+func (*ErrInvalidEndpointState) String() string { return "endpoint is in invalid state" }
+
+// ErrInvalidOptionValue indicates an invalid option value was provided.
+//
+// +stateify savable
+type ErrInvalidOptionValue struct{}
+
+func (*ErrInvalidOptionValue) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidOptionValue) IgnoreStats() bool {
+ return false
+}
+func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" }
+
+// ErrMalformedHeader indicates the operation encountered a malformed header.
+//
+// +stateify savable
+type ErrMalformedHeader struct{}
+
+func (*ErrMalformedHeader) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrMalformedHeader) IgnoreStats() bool {
+ return false
+}
+func (*ErrMalformedHeader) String() string { return "header is malformed" }
+
+// ErrMessageTooLong indicates the operation encountered a message whose length
+// exceeds the maximum permitted.
+//
+// +stateify savable
+type ErrMessageTooLong struct{}
+
+func (*ErrMessageTooLong) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrMessageTooLong) IgnoreStats() bool {
+ return false
+}
+func (*ErrMessageTooLong) String() string { return "message too long" }
+
+// ErrNetworkUnreachable indicates the operation is not able to reach the
+// destination network.
+//
+// +stateify savable
+type ErrNetworkUnreachable struct{}
+
+func (*ErrNetworkUnreachable) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNetworkUnreachable) IgnoreStats() bool {
+ return false
+}
+func (*ErrNetworkUnreachable) String() string { return "network is unreachable" }
+
+// ErrNoBufferSpace indicates no buffer space is available.
+//
+// +stateify savable
+type ErrNoBufferSpace struct{}
+
+func (*ErrNoBufferSpace) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoBufferSpace) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoBufferSpace) String() string { return "no buffer space available" }
+
+// ErrNoPortAvailable indicates no port could be allocated for the operation.
+//
+// +stateify savable
+type ErrNoPortAvailable struct{}
+
+func (*ErrNoPortAvailable) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoPortAvailable) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoPortAvailable) String() string { return "no ports are available" }
+
+// ErrNoRoute indicates the operation is not able to find a route to the
+// destination.
+//
+// +stateify savable
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoRoute) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrNoSuchFile is used to indicate that ENOENT should be returned the to
+// calling application.
+//
+// +stateify savable
+type ErrNoSuchFile struct{}
+
+func (*ErrNoSuchFile) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoSuchFile) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoSuchFile) String() string { return "no such file" }
+
+// ErrNotConnected indicates the endpoint is not connected.
+//
+// +stateify savable
+type ErrNotConnected struct{}
+
+func (*ErrNotConnected) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotConnected) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotConnected) String() string { return "endpoint not connected" }
+
+// ErrNotPermitted indicates the operation is not permitted.
+//
+// +stateify savable
+type ErrNotPermitted struct{}
+
+func (*ErrNotPermitted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotPermitted) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotPermitted) String() string { return "operation not permitted" }
+
+// ErrNotSupported indicates the operation is not supported.
+//
+// +stateify savable
+type ErrNotSupported struct{}
+
+func (*ErrNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotSupported) String() string { return "operation not supported" }
+
+// ErrPortInUse indicates the provided port is in use.
+//
+// +stateify savable
+type ErrPortInUse struct{}
+
+func (*ErrPortInUse) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrPortInUse) IgnoreStats() bool {
+ return false
+}
+func (*ErrPortInUse) String() string { return "port is in use" }
+
+// ErrQueueSizeNotSupported indicates the endpoint does not allow queue size
+// operation.
+//
+// +stateify savable
+type ErrQueueSizeNotSupported struct{}
+
+func (*ErrQueueSizeNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrQueueSizeNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrQueueSizeNotSupported) String() string { return "queue size querying not supported" }
+
+// ErrTimeout indicates the operation timed out.
+//
+// +stateify savable
+type ErrTimeout struct{}
+
+func (*ErrTimeout) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrTimeout) IgnoreStats() bool {
+ return false
+}
+func (*ErrTimeout) String() string { return "operation timed out" }
+
+// ErrUnknownDevice indicates an unknown device identifier was provided.
+//
+// +stateify savable
+type ErrUnknownDevice struct{}
+
+func (*ErrUnknownDevice) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownDevice) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownDevice) String() string { return "unknown device" }
+
+// ErrUnknownNICID indicates an unknown NIC ID was provided.
+//
+// +stateify savable
+type ErrUnknownNICID struct{}
+
+func (*ErrUnknownNICID) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownNICID) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownNICID) String() string { return "unknown nic id" }
+
+// ErrUnknownProtocol indicates an unknown protocol was requested.
+//
+// +stateify savable
+type ErrUnknownProtocol struct{}
+
+func (*ErrUnknownProtocol) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownProtocol) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownProtocol) String() string { return "unknown protocol" }
+
+// ErrUnknownProtocolOption indicates an unknown protocol option was provided.
+//
+// +stateify savable
+type ErrUnknownProtocolOption struct{}
+
+func (*ErrUnknownProtocolOption) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownProtocolOption) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownProtocolOption) String() string { return "unknown option for protocol" }
+
+// ErrWouldBlock indicates the operation would block.
+//
+// +stateify savable
+type ErrWouldBlock struct{}
+
+func (*ErrWouldBlock) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrWouldBlock) IgnoreStats() bool {
+ return true
+}
+func (*ErrWouldBlock) String() string { return "operation would block" }
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
index 114d43df3..bb9d44aff 100644
--- a/pkg/tcpip/faketime/BUILD
+++ b/pkg/tcpip/faketime/BUILD
@@ -6,10 +6,7 @@ go_library(
name = "faketime",
srcs = ["faketime.go"],
visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "@com_github_dpjacques_clockwork//:go_default_library",
- ],
+ deps = ["//pkg/tcpip"],
)
go_test(
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
index f7a4fbde1..fb819d7a8 100644
--- a/pkg/tcpip/faketime/faketime.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -17,10 +17,10 @@ package faketime
import (
"container/heap"
+ "fmt"
"sync"
"time"
- "github.com/dpjacques/clockwork"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -44,38 +44,85 @@ func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer {
return nil
}
+type notificationChannels struct {
+ mu struct {
+ sync.Mutex
+
+ ch []<-chan struct{}
+ }
+}
+
+func (n *notificationChannels) add(ch <-chan struct{}) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ n.mu.ch = append(n.mu.ch, ch)
+}
+
+// wait returns once all the notification channels are readable.
+//
+// Channels that are added while waiting on existing channels will be waited on
+// as well.
+func (n *notificationChannels) wait() {
+ for {
+ n.mu.Lock()
+ ch := n.mu.ch
+ n.mu.ch = nil
+ n.mu.Unlock()
+
+ if len(ch) == 0 {
+ break
+ }
+
+ for _, c := range ch {
+ <-c
+ }
+ }
+}
+
// ManualClock implements tcpip.Clock and only advances manually with Advance
// method.
type ManualClock struct {
- clock clockwork.FakeClock
+ // runningTimers tracks the completion of timer callbacks that began running
+ // immediately upon their scheduling. It is used to ensure the proper ordering
+ // of timer callback dispatch.
+ runningTimers notificationChannels
+
+ mu struct {
+ sync.RWMutex
- // mu protects the fields below.
- mu sync.RWMutex
+ // now is the current (fake) time of the clock.
+ now time.Time
- // times is min-heap of times. A heap is used for quick retrieval of the next
- // upcoming time of scheduled work.
- times *timeHeap
+ // times is min-heap of times.
+ times timeHeap
- // waitGroups stores one WaitGroup for all work scheduled to execute at the
- // same time via AfterFunc. This allows parallel execution of all functions
- // passed to AfterFunc scheduled for the same time.
- waitGroups map[time.Time]*sync.WaitGroup
+ // timers holds the timers scheduled for each time.
+ timers map[time.Time]map[*manualTimer]struct{}
+ }
}
// NewManualClock creates a new ManualClock instance.
func NewManualClock() *ManualClock {
- return &ManualClock{
- clock: clockwork.NewFakeClock(),
- times: &timeHeap{},
- waitGroups: make(map[time.Time]*sync.WaitGroup),
- }
+ c := &ManualClock{}
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Set the initial time to a non-zero value since the zero value is used to
+ // detect inactive timers.
+ c.mu.now = time.Unix(0, 0)
+ c.mu.timers = make(map[time.Time]map[*manualTimer]struct{})
+
+ return c
}
var _ tcpip.Clock = (*ManualClock)(nil)
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
func (mc *ManualClock) NowNanoseconds() int64 {
- return mc.clock.Now().UnixNano()
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
+ return mc.mu.now.UnixNano()
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
@@ -85,128 +132,203 @@ func (mc *ManualClock) NowMonotonic() int64 {
// AfterFunc implements tcpip.Clock.AfterFunc.
func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- until := mc.clock.Now().Add(d)
- wg := mc.addWait(until)
- return &manualTimer{
+ mt := &manualTimer{
clock: mc,
- until: until,
- timer: mc.clock.AfterFunc(d, func() {
- defer wg.Done()
- f()
- }),
+ f: f,
}
-}
-// addWait adds an additional wait to the WaitGroup for parallel execution of
-// all work scheduled for t. Returns a reference to the WaitGroup modified.
-func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
- mc.mu.RLock()
- wg, ok := mc.waitGroups[t]
- mc.mu.RUnlock()
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
- if ok {
- wg.Add(1)
- return wg
+ mc.resetTimerLocked(mt, d)
+ return mt
+}
+
+// resetTimerLocked schedules a timer to be fired after the given duration.
+//
+// Precondition: mc.mu and mt.mu must be locked.
+func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) {
+ if !mt.mu.firesAt.IsZero() {
+ panic("tried to reset an active timer")
}
- mc.mu.Lock()
- heap.Push(mc.times, t)
- mc.mu.Unlock()
+ t := mc.mu.now.Add(d)
- wg = &sync.WaitGroup{}
- wg.Add(1)
+ if !mc.mu.now.Before(t) {
+ // If the timer is scheduled to fire immediately, call its callback
+ // in a new goroutine immediately.
+ //
+ // It needs to be called in its own goroutine to escape its current
+ // execution context - like an actual timer.
+ ch := make(chan struct{})
+ mc.runningTimers.add(ch)
- mc.mu.Lock()
- mc.waitGroups[t] = wg
- mc.mu.Unlock()
+ go func() {
+ defer close(ch)
+
+ mt.f()
+ }()
- return wg
+ return
+ }
+
+ mt.mu.firesAt = t
+
+ timers, ok := mc.mu.timers[t]
+ if !ok {
+ timers = make(map[*manualTimer]struct{})
+ mc.mu.timers[t] = timers
+ heap.Push(&mc.mu.times, t)
+ }
+
+ timers[mt] = struct{}{}
}
-// removeWait removes a wait from the WaitGroup for parallel execution of all
-// work scheduled for t.
-func (mc *ManualClock) removeWait(t time.Time) {
- mc.mu.RLock()
- defer mc.mu.RUnlock()
+// stopTimerLocked stops a timer from firing.
+//
+// Precondition: mc.mu and mt.mu must be locked.
+func (mc *ManualClock) stopTimerLocked(mt *manualTimer) {
+ t := mt.mu.firesAt
+ mt.mu.firesAt = time.Time{}
+
+ if t.IsZero() {
+ panic("tried to stop an inactive timer")
+ }
- wg := mc.waitGroups[t]
- wg.Done()
+ timers, ok := mc.mu.timers[t]
+ if !ok {
+ err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt)
+ for t := range mc.mu.timers {
+ err += fmt.Sprintf("%s\n", t.UTC())
+ }
+ panic(err)
+ }
+
+ if _, ok := timers[mt]; !ok {
+ panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC()))
+ }
+
+ delete(timers, mt)
+
+ if len(timers) == 0 {
+ delete(mc.mu.timers, t)
+ }
}
// Advance executes all work that have been scheduled to execute within d from
-// the current time. Blocks until all work has completed execution.
+// the current time. Blocks until all work has completed execution.
func (mc *ManualClock) Advance(d time.Duration) {
- // Block until all the work is done
- until := mc.clock.Now().Add(d)
- for {
- mc.mu.Lock()
- if mc.times.Len() == 0 {
- mc.mu.Unlock()
- break
- }
+ // We spawn goroutines for timers that were scheduled to fire at the time of
+ // being reset. Wait for those goroutines to complete before proceeding so
+ // that timer callbacks are called in the right order.
+ mc.runningTimers.wait()
- t := heap.Pop(mc.times).(time.Time)
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ until := mc.mu.now.Add(d)
+ for mc.mu.times.Len() > 0 {
+ t := heap.Pop(&mc.mu.times).(time.Time)
if t.After(until) {
// No work to do
- heap.Push(mc.times, t)
- mc.mu.Unlock()
+ heap.Push(&mc.mu.times, t)
break
}
- mc.mu.Unlock()
- diff := t.Sub(mc.clock.Now())
- mc.clock.Advance(diff)
+ timers := mc.mu.timers[t]
+ delete(mc.mu.timers, t)
+
+ mc.mu.now = t
+
+ // Mark the timers as inactive since they will be fired.
+ //
+ // This needs to be done while holding mc's lock because we remove the entry
+ // in the map of timers for the current time. If an attempt to stop a
+ // timer is made after mc's lock was dropped but before the timer is
+ // marked inactive, we would panic since no entry exists for the time when
+ // the timer was expected to fire.
+ for mt := range timers {
+ mt.mu.Lock()
+ mt.mu.firesAt = time.Time{}
+ mt.mu.Unlock()
+ }
- mc.mu.RLock()
- wg := mc.waitGroups[t]
- mc.mu.RUnlock()
+ // Release the lock before calling the timer's callback fn since the
+ // callback fn might try to schedule a timer which requires obtaining
+ // mc's lock.
+ mc.mu.Unlock()
- wg.Wait()
+ for mt := range timers {
+ mt.f()
+ }
+ // The timer callbacks may have scheduled a timer to fire immediately.
+ // We spawn goroutines for these timers and need to wait for them to
+ // finish before proceeding so that timer callbacks are called in the
+ // right order.
+ mc.runningTimers.wait()
mc.mu.Lock()
- delete(mc.waitGroups, t)
- mc.mu.Unlock()
}
- if now := mc.clock.Now(); until.After(now) {
- mc.clock.Advance(until.Sub(now))
+
+ mc.mu.now = until
+}
+
+func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
+
+ if !mt.mu.firesAt.IsZero() {
+ mc.stopTimerLocked(mt)
}
+
+ mc.resetTimerLocked(mt, d)
+}
+
+func (mc *ManualClock) stopTimer(mt *manualTimer) bool {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
+
+ if mt.mu.firesAt.IsZero() {
+ return false
+ }
+
+ mc.stopTimerLocked(mt)
+ return true
}
type manualTimer struct {
clock *ManualClock
- timer clockwork.Timer
+ f func()
- mu sync.RWMutex
- until time.Time
+ mu struct {
+ sync.Mutex
+
+ // firesAt is the time when the timer will fire.
+ //
+ // Zero only when the timer is not active.
+ firesAt time.Time
+ }
}
var _ tcpip.Timer = (*manualTimer)(nil)
// Reset implements tcpip.Timer.Reset.
-func (t *manualTimer) Reset(d time.Duration) {
- if !t.timer.Reset(d) {
- return
- }
-
- t.mu.Lock()
- defer t.mu.Unlock()
-
- t.clock.removeWait(t.until)
- t.until = t.clock.clock.Now().Add(d)
- t.clock.addWait(t.until)
+func (mt *manualTimer) Reset(d time.Duration) {
+ mt.clock.resetTimer(mt, d)
}
// Stop implements tcpip.Timer.Stop.
-func (t *manualTimer) Stop() bool {
- if !t.timer.Stop() {
- return false
- }
-
- t.mu.RLock()
- defer t.mu.RUnlock()
-
- t.clock.removeWait(t.until)
- return true
+func (mt *manualTimer) Stop() bool {
+ return mt.clock.stopTimer(mt)
}
type timeHeap []time.Time
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e6103f4bc..48ca60319 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@ package header
import (
"encoding/binary"
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -481,15 +480,13 @@ const (
IPv4OptionLengthOffset = 1
)
-// Potential errors when parsing generic IP options.
-var (
- ErrIPv4OptZeroLength = errors.New("zero length IP option")
- ErrIPv4OptDuplicate = errors.New("duplicate IP option")
- ErrIPv4OptInvalid = errors.New("invalid IP option")
- ErrIPv4OptMalformed = errors.New("malformed IP option")
- ErrIPv4OptionTruncated = errors.New("truncated IP option")
- ErrIPv4OptionAddress = errors.New("bad IP option address")
-)
+// IPv4OptParameterProblem indicates that a Parameter Problem message
+// should be generated, and gives the offset in the current entity
+// that should be used in that packet.
+type IPv4OptParameterProblem struct {
+ Pointer uint8
+ NeedICMP bool
+}
// IPv4Option is an interface representing various option types.
type IPv4Option interface {
@@ -583,8 +580,9 @@ func (i *IPv4OptionIterator) Finalize() IPv4Options {
// It returns
// - A slice of bytes holding the next option or nil if there is error.
// - A boolean which is true if parsing of all the options is complete.
-// - An error which is non-nil if an error condition was encountered.
-func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
+// Undefined in the case of error.
+// - An error indication which is non-nil if an error condition was found.
+func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) {
// The opts slice gets shorter as we process the options. When we have no
// bytes left we are done.
if len(i.options) == 0 {
@@ -606,24 +604,22 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
// There are no more single byte options defined. All the rest have a length
// field so we need to sanity check it.
if len(i.options) == 1 {
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
optLen := i.options[IPv4OptionLengthOffset]
- if optLen == 0 {
- i.ErrCursor++
- return nil, true, ErrIPv4OptZeroLength
- }
+ if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) {
+ // The actual error is in the length (2nd byte of the option) but we
+ // return the start of the option for compatibility with Linux.
- if optLen == 1 {
- i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
- }
-
- if optLen > uint8(len(i.options)) {
- i.ErrCursor++
- return nil, true, ErrIPv4OptionTruncated
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
optionBody := i.options[:optLen]
@@ -635,7 +631,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
case IPv4OptionTimestampType:
if optLen < IPv4OptionTimestampHdrLength {
i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
retval := IPv4OptionTimestamp(optionBody)
return &retval, false, nil
@@ -643,7 +642,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
case IPv4OptionRecordRouteType:
if optLen < IPv4OptionRecordRouteHdrLength {
i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
retval := IPv4OptionRecordRoute(optionBody)
return &retval, false, nil
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 5580d6a78..f2403978c 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -453,9 +453,9 @@ const (
)
// ScopeForIPv6Address returns the scope for an IPv6 address.
-func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
if len(addr) != IPv6AddressSize {
- return GlobalScope, tcpip.ErrBadAddress
+ return GlobalScope, &tcpip.ErrBadAddress{}
}
switch {
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index e3fbd64f3..f10f446a6 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -299,7 +299,7 @@ func TestScopeForIPv6Address(t *testing.T) {
name string
addr tcpip.Address
scope header.IPv6AddressScope
- err *tcpip.Error
+ err tcpip.Error
}{
{
name: "Unique Local",
@@ -329,15 +329,15 @@ func TestScopeForIPv6Address(t *testing.T) {
name: "IPv4",
addr: "\x01\x02\x03\x04",
scope: header.GlobalScope,
- err: tcpip.ErrBadAddress,
+ err: &tcpip.ErrBadAddress{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := header.ScopeForIPv6Address(test.addr)
- if err != test.err {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ if diff := cmp.Diff(test.err, err); diff != "" {
+ t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff)
}
if got != test.scope {
t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a068d93a4..cd76272de 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
@@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 2f2d9d4ac..d873766a6 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 10072eac1..ae1394ebf 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -35,7 +35,6 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
"//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f86c383d8..0164d851b 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -57,7 +57,7 @@ import (
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
- dispatch() (bool, *tcpip.Error)
+ dispatch() (bool, tcpip.Error)
}
// PacketDispatchMode are the various supported methods of receiving and
@@ -118,7 +118,7 @@ type endpoint struct {
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
- closed func(*tcpip.Error)
+ closed func(tcpip.Error)
inboundDispatchers []linkDispatcher
dispatcher stack.NetworkDispatcher
@@ -149,7 +149,7 @@ type Options struct {
// ClosedFunc is a function to be called when an endpoint's peer (if
// any) closes its end of the communication pipe.
- ClosedFunc func(*tcpip.Error)
+ ClosedFunc func(tcpip.Error)
// Address is the link address for this endpoint. Only used if
// EthernetHeader is true.
@@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if e.hdrSize > 0 {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
@@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
@@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// - pkt.EgressRoute
// - pkt.GSOOptions
// - pkt.NetworkProtocolNumber
-func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
// Preallocate to avoid repeated reallocation as we append to batch.
// batchSz is 47 because when SWGSO is in use then a single 65KB TCP
// segment can get split into 46 segments of 1420 bytes and a single 216
@@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
}
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
-func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
for {
cont, err := inboundDispatcher.dispatch()
if err != nil || !cont {
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 90da22d34..e82371798 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -96,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context {
}
done := make(chan struct{}, 2)
- opt.ClosedFunc = func(*tcpip.Error) {
+ opt.ClosedFunc = func(tcpip.Error) {
done <- struct{}{}
}
@@ -465,67 +464,85 @@ var capLengthTestCases = []struct {
config: []int{1, 2, 3},
n: 3,
wantUsed: 2,
- wantLengths: []int{1, 2, 3},
+ wantLengths: []int{1, 2},
},
}
-func TestReadVDispatcherCapLength(t *testing.T) {
+func TestIovecBuffer(t *testing.T) {
for _, c := range capLengthTestCases {
- // fd does not matter for this test.
- d := readVDispatcher{fd: -1, e: &endpoint{}}
- d.views = make([]buffer.View, len(c.config))
- d.iovecs = make([]syscall.Iovec, len(c.config))
- d.allocateViews(c.config)
-
- used := d.capViews(c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views))
- for i, v := range d.views {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
- }
-}
+ t.Run(c.comment, func(t *testing.T) {
+ b := newIovecBuffer(c.config, false /* skipsVnetHdr */)
-func TestRecvMMsgDispatcherCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- d := recvMMsgDispatcher{
- fd: -1, // fd does not matter for this test.
- e: &endpoint{},
- views: make([][]buffer.View, 1),
- iovecs: make([][]syscall.Iovec, 1),
- msgHdrs: make([]rawfile.MMsgHdr, 1),
- }
+ // Test initial allocation.
+ iovecs := b.nextIovecs()
+ if got, want := len(iovecs), len(c.config); got != want {
+ t.Fatalf("len(iovecs) = %d, want %d", got, want)
+ }
- for i := range d.views {
- d.views[i] = make([]buffer.View, len(c.config))
- }
- for i := range d.iovecs {
- d.iovecs[i] = make([]syscall.Iovec, len(c.config))
- }
- for k, msgHdr := range d.msgHdrs {
- msgHdr.Msg.Iov = &d.iovecs[k][0]
- msgHdr.Msg.Iovlen = uint64(len(c.config))
- }
+ // Make a copy as iovecs points to internal slice. We will need this state
+ // later.
+ oldIovecs := append([]syscall.Iovec(nil), iovecs...)
- d.allocateViews(c.config)
+ // Test the views that get pulled.
+ vv := b.pullViews(c.n)
+ var lengths []int
+ for _, v := range vv.Views() {
+ lengths = append(lengths, len(v))
+ }
+ if !reflect.DeepEqual(lengths, c.wantLengths) {
+ t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths)
+ }
- used := d.capViews(0, c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views[0]))
- for i, v := range d.views[0] {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
+ // Test that new views get reallocated.
+ for i, newIov := range b.nextIovecs() {
+ if i < c.wantUsed {
+ if newIov.Base == oldIovecs[i].Base {
+ t.Errorf("b.views[%d] should have been reallocated", i)
+ }
+ } else {
+ if newIov.Base != oldIovecs[i].Base {
+ t.Errorf("b.views[%d] should not have been reallocated", i)
+ }
+ }
+ }
+ })
+ }
+}
+func TestIovecBufferSkipVnetHdr(t *testing.T) {
+ for _, test := range []struct {
+ desc string
+ readN int
+ wantLen int
+ }{
+ {
+ desc: "nothing read",
+ readN: 0,
+ wantLen: 0,
+ },
+ {
+ desc: "smaller than vnet header",
+ readN: virtioNetHdrSize - 1,
+ wantLen: 0,
+ },
+ {
+ desc: "header skipped",
+ readN: virtioNetHdrSize + 100,
+ wantLen: 100,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ b := newIovecBuffer([]int{10, 20, 50, 50}, true)
+ // Pretend a read happend.
+ b.nextIovecs()
+ vv := b.pullViews(test.readN)
+ if got, want := vv.Size(), test.wantLen; got != want {
+ t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want)
+ }
+ if got, want := len(vv.ToOwnedView()), test.wantLen; got != want {
+ t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index c475dda20..a2b63fe6b 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -129,7 +129,7 @@ type packetMMapDispatcher struct {
ringOffset int
}
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
for hdr.tpStatus()&tpStatusUser == 0 {
event := rawfile.PollEvent{
@@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
// dispatch reads packets from an mmaped ring buffer and dispatches them to the
// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
pkt, err := d.readMMappedPacket()
if err != nil {
return false, err
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 8c3ca86d6..ecae1ad2d 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -29,92 +29,124 @@ import (
// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
-// readVDispatcher uses readv() system call to read inbound packets and
-// dispatches them.
-type readVDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
+type iovecBuffer struct {
// views are the actual buffers that hold the packet contents.
views []buffer.View
// iovecs are initialized with base pointers/len of the corresponding
- // entries in the views defined above, except when GSO is enabled then
- // the first iovec points to a buffer for the vnet header which is
- // stripped before the views are passed up the stack for further
+ // entries in the views defined above, except when GSO is enabled
+ // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header
+ // which is stripped before the views are passed up the stack for further
// processing.
iovecs []syscall.Iovec
+
+ // sizes is an array of buffer sizes for the underlying views. sizes is
+ // immutable.
+ sizes []int
+
+ // skipsVnetHdr is true if virtioNetHdr is to skipped.
+ skipsVnetHdr bool
}
-func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- d := &readVDispatcher{fd: fd, e: e}
- d.views = make([]buffer.View, len(BufConfig))
- iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- iovLen++
+func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer {
+ b := &iovecBuffer{
+ views: make([]buffer.View, len(sizes)),
+ sizes: sizes,
+ skipsVnetHdr: skipsVnetHdr,
}
- d.iovecs = make([]syscall.Iovec, iovLen)
- return d, nil
+ niov := len(b.views)
+ if b.skipsVnetHdr {
+ niov++
+ }
+ b.iovecs = make([]syscall.Iovec, niov)
+ return b
}
-func (d *readVDispatcher) allocateViews(bufConfig []int) {
- var vnetHdr [virtioNetHdrSize]byte
+func (b *iovecBuffer) nextIovecs() []syscall.Iovec {
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if b.skipsVnetHdr {
+ var vnetHdr [virtioNetHdrSize]byte
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
- d.iovecs[0] = syscall.Iovec{
+ b.iovecs[0] = syscall.Iovec{
Base: &vnetHdr[0],
Len: uint64(virtioNetHdrSize),
}
vnetHdrOff++
}
- for i := 0; i < len(bufConfig); i++ {
- if d.views[i] != nil {
+ for i := range b.views {
+ if b.views[i] != nil {
break
}
- b := buffer.NewView(bufConfig[i])
- d.views[i] = b
- d.iovecs[i+vnetHdrOff] = syscall.Iovec{
- Base: &b[0],
- Len: uint64(len(b)),
+ v := buffer.NewView(b.sizes[i])
+ b.views[i] = v
+ b.iovecs[i+vnetHdrOff] = syscall.Iovec{
+ Base: &v[0],
+ Len: uint64(len(v)),
}
}
+ return b.iovecs
}
-func (d *readVDispatcher) capViews(n int, buffers []int) int {
+func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView {
+ var views []buffer.View
c := 0
- for i, s := range buffers {
- c += s
+ if b.skipsVnetHdr {
+ c += virtioNetHdrSize
if c >= n {
- d.views[i].CapLength(s - (c - n))
- return i + 1
+ // Nothing in the packet.
+ return buffer.NewVectorisedView(0, nil)
+ }
+ }
+ for i, v := range b.views {
+ c += len(v)
+ if c >= n {
+ b.views[i].CapLength(len(v) - (c - n))
+ views = append([]buffer.View(nil), b.views[:i+1]...)
+ break
}
}
- return len(buffers)
+ // Remove the first len(views) used views from the state.
+ for i := range views {
+ b.views[i] = nil
+ }
+ if b.skipsVnetHdr {
+ // Exclude the size of the vnet header.
+ n -= virtioNetHdrSize
+ }
+ return buffer.NewVectorisedView(n, views)
}
-// dispatch reads one packet from the file descriptor and dispatches it.
-func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
- d.allocateViews(BufConfig)
+// readVDispatcher uses readv() system call to read inbound packets and
+// dispatches them.
+type readVDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // buf is the iovec buffer that contains the packet contents.
+ buf *iovecBuffer
+}
+
+func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &readVDispatcher{fd: fd, e: e}
+ skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
+ return d, nil
+}
- n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
+// dispatch reads one packet from the file descriptor and dispatches it.
+func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
+ n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs())
if n == 0 || err != nil {
return false, err
}
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // Skip virtioNetHdr which is added before each packet, it
- // isn't used and it isn't in a view.
- n -= virtioNetHdrSize
- }
- used := d.capViews(n, BufConfig)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ Data: d.buf.pullViews(n),
})
var (
@@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
} else {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(d.views[0]) {
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data.PullUp(1)
+ if !ok {
+ return true, nil
+ }
+ switch header.IPVersion(h) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
@@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
- // Prepare e.views for another packet: release used views.
- for i := 0; i < used; i++ {
- d.views[i] = nil
- }
-
return true, nil
}
@@ -162,15 +194,8 @@ type recvMMsgDispatcher struct {
// e is the endpoint this dispatcher is attached to.
e *endpoint
- // views is an array of array of buffers that contain packet contents.
- views [][]buffer.View
-
- // iovecs is an array of array of iovec records where each iovec base
- // pointer and length are initialzed to the corresponding view above,
- // except when GSO is enabled then the first iovec in each array of
- // iovecs points to a buffer for the vnet header which is stripped
- // before the views are passed up the stack for further processing.
- iovecs [][]syscall.Iovec
+ // bufs is an array of iovec buffers that contain packet contents.
+ bufs []*iovecBuffer
// msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
// reference an array of iovecs in the iovecs field defined above. This
@@ -187,74 +212,32 @@ const (
func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &recvMMsgDispatcher{
- fd: fd,
- e: e,
- }
- d.views = make([][]buffer.View, MaxMsgsPerRecv)
- for i := range d.views {
- d.views[i] = make([]buffer.View, len(BufConfig))
- }
- d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
- iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // virtioNetHdr is prepended before each packet.
- iovLen++
+ fd: fd,
+ e: e,
+ bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
+ msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
}
- for i := range d.iovecs {
- d.iovecs[i] = make([]syscall.Iovec, iovLen)
- }
- d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv)
- for i := range d.msgHdrs {
- d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0]
- d.msgHdrs[i].Msg.Iovlen = uint64(iovLen)
+ skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ for i := range d.bufs {
+ d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
}
return d, nil
}
-func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int {
- c := 0
- for i, s := range buffers {
- c += s
- if c >= n {
- d.views[k][i].CapLength(s - (c - n))
- return i + 1
- }
- }
- return len(buffers)
-}
-
-func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
- for k := 0; k < len(d.views); k++ {
- var vnetHdr [virtioNetHdrSize]byte
- vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // The kernel adds virtioNetHdr before each packet, but
- // we don't use it, so so we allocate a buffer for it,
- // add it in iovecs but don't add it in a view.
- d.iovecs[k][0] = syscall.Iovec{
- Base: &vnetHdr[0],
- Len: uint64(virtioNetHdrSize),
- }
- vnetHdrOff++
- }
- for i := 0; i < len(bufConfig); i++ {
- if d.views[k][i] != nil {
- break
- }
- b := buffer.NewView(bufConfig[i])
- d.views[k][i] = b
- d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{
- Base: &b[0],
- Len: uint64(len(b)),
- }
- }
- }
-}
-
// recvMMsgDispatch reads more than one packet at a time from the file
// descriptor and dispatches it.
-func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
- d.allocateViews(BufConfig)
+func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
+ // Fill message headers.
+ for k := range d.msgHdrs {
+ if d.msgHdrs[k].Msg.Iovlen > 0 {
+ break
+ }
+ iovecs := d.bufs[k].nextIovecs()
+ iovLen := len(iovecs)
+ d.msgHdrs[k].Len = 0
+ d.msgHdrs[k].Msg.Iov = &iovecs[0]
+ d.msgHdrs[k].Msg.Iovlen = uint64(iovLen)
+ }
nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
if err != nil {
@@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
// Process each of received packets.
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].Len)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- n -= virtioNetHdrSize
- }
- used := d.capViews(k, int(n), BufConfig)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ Data: d.bufs[k].pullViews(n),
})
+ // Mark that this iovec has been processed.
+ d.msgHdrs[k].Msg.Iovlen = 0
+
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
@@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
} else {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(d.views[k][0]) {
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data.PullUp(1)
+ if !ok {
+ // Skip this packet.
+ continue
+ }
+ switch header.IPVersion(h) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
p = header.IPv6ProtocolNumber
default:
- return true, nil
+ // Skip this packet.
+ continue
}
}
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
-
- // Prepare e.views for another packet: release used views.
- for i := 0; i < used; i++ {
- d.views[k][i] = nil
- }
- }
-
- for k := 0; k < nMsgs; k++ {
- d.msgHdrs[k].Len = 0
}
return true, nil
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ac6a6be87..691467870 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// Construct data as the unparsed portion for the loopback packet.
data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 316f508e6..668f72eee 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber,
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
return endpoint.WritePackets(r, gso, pkts, protocol)
}
@@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
-func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error {
endpoint, ok := m.routes[dest]
if !ok {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
return endpoint.InjectOutbound(dest, packet)
}
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 814a54f23..97ad9fdd5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return e.child.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return e.child.WritePackets(r, gso, pkts, protocol)
}
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index c95cdd681..6cbe18a56 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index d6e83a414..bbe84f220 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -45,12 +45,7 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
}
-// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- if !e.linked.IsAttached() {
- return nil
- }
-
+func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) {
// Note that the local address from the perspective of this endpoint is the
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
@@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw
//
// TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send
// and receive queues.
- go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
+ go func() {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+ }
+ }()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.linked.IsAttached() {
+ var pkts stack.PacketBufferList
+ pkts.PushBack(pkt)
+ e.deliverPackets(r, proto, pkts)
+ }
return nil
}
// WritePackets implements stack.LinkEndpoint.
-func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- panic("not implemented")
+func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ if e.linked.IsAttached() {
+ e.deliverPackets(r, proto, pkts)
+ }
+
+ return pkts.Len(), nil
}
// Attach implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 87035b034..128ef6e87 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
pkt.EgressRoute = r
@@ -158,26 +158,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
if !d.q.enqueue(pkt) {
- return tcpip.ErrNoBufferSpace
+ return &tcpip.ErrNoBufferSpace{}
}
d.newPacketWaker.Assert()
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+//
+// Being a batch API, each packet in pkts should have the following
+// fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
enqueued := 0
for pkt := pkts.Front(); pkt != nil; {
- pkt.EgressRoute = r
- pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
}
- return enqueued, tcpip.ErrNoBufferSpace
+ return enqueued, &tcpip.ErrNoBufferSpace{}
}
pkt = nxt
enqueued++
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6c410c5a6..e1047da50 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -27,5 +27,6 @@ go_test(
library = "rawfile",
deps = [
"//pkg/tcpip",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 604868fd8..406b97709 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -17,7 +17,6 @@
package rawfile
import (
- "fmt"
"syscall"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -25,48 +24,54 @@ import (
const maxErrno = 134
-var translations [maxErrno]*tcpip.Error
-
// TranslateErrno translate an errno from the syscall package into a
-// *tcpip.Error.
+// tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL).
-func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if e > 0 && e < syscall.Errno(len(translations)) {
- if err := translations[e]; err != nil {
- return err
- }
- }
- return tcpip.ErrInvalidEndpointState
-}
-
-func addTranslation(host syscall.Errno, trans *tcpip.Error) {
- if translations[host] != nil {
- panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+// *tcpip.ErrInvalidEndpointState (EINVAL).
+func TranslateErrno(e syscall.Errno) tcpip.Error {
+ switch e {
+ case syscall.EEXIST:
+ return &tcpip.ErrDuplicateAddress{}
+ case syscall.ENETUNREACH:
+ return &tcpip.ErrNoRoute{}
+ case syscall.EINVAL:
+ return &tcpip.ErrInvalidEndpointState{}
+ case syscall.EALREADY:
+ return &tcpip.ErrAlreadyConnecting{}
+ case syscall.EISCONN:
+ return &tcpip.ErrAlreadyConnected{}
+ case syscall.EADDRINUSE:
+ return &tcpip.ErrPortInUse{}
+ case syscall.EADDRNOTAVAIL:
+ return &tcpip.ErrBadLocalAddress{}
+ case syscall.EPIPE:
+ return &tcpip.ErrClosedForSend{}
+ case syscall.EWOULDBLOCK:
+ return &tcpip.ErrWouldBlock{}
+ case syscall.ECONNREFUSED:
+ return &tcpip.ErrConnectionRefused{}
+ case syscall.ETIMEDOUT:
+ return &tcpip.ErrTimeout{}
+ case syscall.EINPROGRESS:
+ return &tcpip.ErrConnectStarted{}
+ case syscall.EDESTADDRREQ:
+ return &tcpip.ErrDestinationRequired{}
+ case syscall.ENOTSUP:
+ return &tcpip.ErrNotSupported{}
+ case syscall.ENOTTY:
+ return &tcpip.ErrQueueSizeNotSupported{}
+ case syscall.ENOTCONN:
+ return &tcpip.ErrNotConnected{}
+ case syscall.ECONNRESET:
+ return &tcpip.ErrConnectionReset{}
+ case syscall.ECONNABORTED:
+ return &tcpip.ErrConnectionAborted{}
+ case syscall.EMSGSIZE:
+ return &tcpip.ErrMessageTooLong{}
+ case syscall.ENOBUFS:
+ return &tcpip.ErrNoBufferSpace{}
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
}
- translations[host] = trans
-}
-
-func init() {
- addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
- addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
- addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
- addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
- addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
- addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
- addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
- addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
- addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
- addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
- addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
- addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
- addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
- addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
- addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
- addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
- addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
- addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
- addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
- addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
index e4cdc66bd..61aea1744 100644
--- a/pkg/tcpip/link/rawfile/errors_test.go
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -20,34 +20,35 @@ import (
"syscall"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
func TestTranslateErrno(t *testing.T) {
for _, test := range []struct {
errno syscall.Errno
- translated *tcpip.Error
+ translated tcpip.Error
}{
{
errno: syscall.Errno(0),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.Errno(maxErrno),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.Errno(514),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.EEXIST,
- translated: tcpip.ErrDuplicateAddress,
+ translated: &tcpip.ErrDuplicateAddress{},
},
} {
got := TranslateErrno(test.errno)
- if got != test.translated {
- t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ if diff := cmp.Diff(test.translated, got); diff != "" {
+ t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff)
}
}
}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index f4c32c2da..06f3ee21e 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) {
// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
// partial data is written.
-func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+func NonBlockingWrite(fd int, buf []byte) tcpip.Error {
var ptr unsafe.Pointer
if len(buf) > 0 {
ptr = unsafe.Pointer(&buf[0])
@@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall.
// It fails if partial data is written.
-func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error {
iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
@@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
}
// NonBlockingSendMMsg sends multiple messages on a socket.
-func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
if e != 0 {
return 0, TranslateErrno(e)
@@ -97,7 +97,7 @@ type PollEvent struct {
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
-func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
@@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
// BlockingReadv reads from a file descriptor that is set up as non-blocking and
// stores the data in a list of iovecs buffers. If no data is available, it will
// block in a poll() syscall until the file descriptor becomes readable.
-func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
@@ -149,7 +149,7 @@ type MMsgHdr struct {
// and stores the received messages in a slice of MMsgHdr structures. If no data
// is available, it will block in a poll() syscall until the file descriptor
// becomes readable.
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
if e == 0 {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 6c937c858..2599bc406 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
views := pkt.Views()
@@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
e.mu.Unlock()
if !ok {
- return tcpip.ErrWouldBlock
+ return &tcpip.ErrWouldBlock{}
}
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 23242b9e0..d480ad656 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) {
Data: buf.ToVectorisedView(),
})
err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
- if want := tcpip.ErrWouldBlock; err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buffer.NewView(bufferSize).ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 5859851d8..bd2b8d4bf 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.dumpPacket(directionSend, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index bfac358f4..3829ca9c9 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
Name: endpoint.name,
})
- switch err {
+ switch err.(type) {
case nil:
return endpoint, nil
- case tcpip.ErrDuplicateNICID:
+ case *tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 30f1ad540..20259b285 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
@@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
if !e.writeGate.Enter() {
return pkts.Len(), nil
}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index b139de7dd..e368a9eaa 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
e.writeCount += pkts.Len()
return pkts.Len(), nil
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9ebf31b78..0caa65251 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -25,5 +25,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 8a6bcfc2c..c7ab876bf 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -4,9 +4,13 @@ package(licenses = ["notice"])
go_library(
name = "arp",
- srcs = ["arp.go"],
+ srcs = [
+ "arp.go",
+ "stats.go",
+ ],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -33,3 +37,15 @@ go_test(
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
+
+go_test(
+ name = "stats_test",
+ size = "small",
+ srcs = ["stats_test.go"],
+ library = ":arp",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 3259d052f..7838cc753 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -19,8 +19,10 @@ package arp
import (
"fmt"
+ "reflect"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -50,11 +52,12 @@ type endpoint struct {
nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
+ stats sharedStats
}
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
e.setEnabled(true)
@@ -98,10 +101,12 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.protocol.forgetEndpoint(e.nic.ID())
+}
-func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -110,51 +115,45 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
+ return 0, &tcpip.ErrNotSupported{}
}
-func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats().ARP
- stats.PacketsReceived.Increment()
+ stats := e.stats.arp
+ stats.packetsReceived.Increment()
if !e.isEnabled() {
- stats.DisabledPacketsReceived.Increment()
+ stats.disabledPacketsReceived.Increment()
return
}
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
- stats.MalformedPacketsReceived.Increment()
+ stats.malformedPacketsReceived.Increment()
return
}
switch h.Op() {
case header.ARPRequest:
- stats.RequestsReceived.Increment()
+ stats.requestsReceived.Increment()
localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ stats.requestsReceivedUnknownTargetAddress.Increment()
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.RequestsReceivedUnknownTargetAddress.Increment()
- return // we have no useful answer, ignore the request
- }
-
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.RequestsReceivedUnknownTargetAddress.Increment()
- return // we have no useful answer, ignore the request
- }
-
- remoteAddr := tcpip.Address(h.ProtocolAddressSender())
- remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
@@ -186,18 +185,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Send the packet to the (new) target hardware address on the same
// hardware on which the request was received.
if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil {
- stats.OutgoingRepliesDropped.Increment()
+ stats.outgoingRepliesDropped.Increment()
} else {
- stats.OutgoingRepliesSent.Increment()
+ stats.outgoingRepliesSent.Increment()
}
case header.ARPReply:
- stats.RepliesReceived.Increment()
+ stats.repliesReceived.Increment()
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(addr, linkAddr)
return
}
@@ -216,9 +215,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
+var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ // eps is keyed by NICID to allow protocol methods to retrieve the correct
+ // endpoint depending on the NIC.
+ eps map[tcpip.NICID]*endpoint
+ }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -236,39 +251,62 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
linkAddrCache: linkAddrCache,
nud: nud,
}
+
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+
+ stackStats := p.stack.Stats()
+ e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
+
+ p.mu.Lock()
+ p.mu.eps[nic.ID()] = e
+ p.mu.Unlock()
+
return e
}
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, nicID)
+}
+
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
- stats := p.stack.Stats().ARP
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
+ nicID := nic.ID()
+
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[nicID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ stats := netEP.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
- nicID := nic.ID()
if len(localAddr) == 0 {
- addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
- if err != nil {
- stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
- return err
+ addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
}
if len(addr.Address) == 0 {
- stats.OutgoingRequestNetworkUnreachableErrors.Increment()
- return tcpip.ErrNetworkUnreachable
+ stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
+ return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addr.Address
} else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.OutgoingRequestBadLocalAddressErrors.Increment()
- return tcpip.ErrBadLocalAddress
+ stats.outgoingRequestBadLocalAddressErrors.Increment()
+ return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -288,10 +326,10 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
- stats.OutgoingRequestsDropped.Increment()
+ stats.outgoingRequestsDropped.Increment()
return err
}
- stats.OutgoingRequestsSent.Increment()
+ stats.outgoingRequestsSent.Increment()
return nil
}
@@ -307,13 +345,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
@@ -329,5 +367,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
// NewProtocol returns an ARP network protocol.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
- return &protocol{stack: s}
+ return &protocol{
+ stack: s,
+ mu: struct {
+ sync.RWMutex
+ eps map[tcpip.NICID]*endpoint
+ }{eps: make(map[tcpip.NICID]*endpoint)},
+ }
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 0536e1698..b0f07aa44 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
}
case <-ctx.Done():
return fmt.Errorf("%s for %s", ctx.Err(), want)
@@ -537,7 +537,7 @@ type testInterface struct {
nicID tcpip.NICID
- writeErr *tcpip.Error
+ writeErr tcpip.Error
}
func (t *testInterface) ID() tcpip.NICID {
@@ -560,15 +560,15 @@ func (*testInterface) Promiscuous() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
@@ -585,104 +585,122 @@ func TestLinkAddressRequest(t *testing.T) {
testAddr := tcpip.Address([]byte{1, 2, 3, 4})
tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
-
- linkErr *tcpip.Error
- expectedErr *tcpip.Error
- expectedLocalAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- expectedRequestsSent uint64
- expectedRequestBadLocalAddressErrors uint64
- expectedRequestNetworkUnreachableErrors uint64
- expectedRequestDroppedErrors uint64
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+ linkErr tcpip.Error
+ expectedErr tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
+ expectedRequestsSent uint64
+ expectedRequestBadLocalAddressErrors uint64
+ expectedRequestInterfaceHasNoLocalAddressErrors uint64
+ expectedRequestDroppedErrors uint64
}{
{
- name: "Unicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ localAddr: "",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ localAddr: "",
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrBadLocalAddress,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast with unassigned address",
+ nicAddr: stackAddr,
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrBadLocalAddress,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast with unassigned address",
+ nicAddr: stackAddr,
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: &tcpip.ErrBadLocalAddress{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with no local address available",
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrNetworkUnreachable,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 1,
+ name: "Unicast with no local address available",
+ nicAddr: "",
+ localAddr: "",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 1,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with no local address available",
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 1,
+ name: "Multicast with no local address available",
+ nicAddr: "",
+ localAddr: "",
+ remoteLinkAddr: "",
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 1,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Link error",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- linkErr: tcpip.ErrInvalidEndpointState,
- expectedErr: tcpip.ErrInvalidEndpointState,
- expectedRequestDroppedErrors: 1,
+ name: "Link error",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ linkErr: &tcpip.ErrInvalidEndpointState{},
+ expectedErr: &tcpip.ErrInvalidEndpointState{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 1,
},
}
@@ -714,19 +732,20 @@ func TestLinkAddressRequest(t *testing.T) {
// link endpoint even though the stack uses the real NIC to validate the
// local address.
iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent)
}
+ if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors)
+ }
if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors {
t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors)
}
- if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors)
- }
if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors {
t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors)
}
@@ -774,11 +793,8 @@ func TestLinkAddressRequestWithoutNIC(t *testing.T) {
t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
}
- if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID)
- }
-
- if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 {
- t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got)
+ err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID})
+ if _, ok := err.(*tcpip.ErrNotConnected); !ok {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{})
}
}
diff --git a/pkg/tcpip/network/arp/stats.go b/pkg/tcpip/network/arp/stats.go
new file mode 100644
index 000000000..6d7194c6c
--- /dev/null
+++ b/pkg/tcpip/network/arp/stats.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to ARP.
+type Stats struct {
+ // ARP holds ARP statistics.
+ ARP tcpip.ARPStats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+type sharedStats struct {
+ localStats Stats
+ arp multiCounterARPStats
+}
+
+// LINT.IfChange(multiCounterARPStats)
+
+type multiCounterARPStats struct {
+ packetsReceived tcpip.MultiCounterStat
+ disabledPacketsReceived tcpip.MultiCounterStat
+ malformedPacketsReceived tcpip.MultiCounterStat
+ requestsReceived tcpip.MultiCounterStat
+ requestsReceivedUnknownTargetAddress tcpip.MultiCounterStat
+ outgoingRequestInterfaceHasNoLocalAddressErrors tcpip.MultiCounterStat
+ outgoingRequestBadLocalAddressErrors tcpip.MultiCounterStat
+ outgoingRequestsDropped tcpip.MultiCounterStat
+ outgoingRequestsSent tcpip.MultiCounterStat
+ repliesReceived tcpip.MultiCounterStat
+ outgoingRepliesDropped tcpip.MultiCounterStat
+ outgoingRepliesSent tcpip.MultiCounterStat
+}
+
+func (m *multiCounterARPStats) init(a, b *tcpip.ARPStats) {
+ m.packetsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.disabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
+ m.malformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
+ m.requestsReceived.Init(a.RequestsReceived, b.RequestsReceived)
+ m.requestsReceivedUnknownTargetAddress.Init(a.RequestsReceivedUnknownTargetAddress, b.RequestsReceivedUnknownTargetAddress)
+ m.outgoingRequestInterfaceHasNoLocalAddressErrors.Init(a.OutgoingRequestInterfaceHasNoLocalAddressErrors, b.OutgoingRequestInterfaceHasNoLocalAddressErrors)
+ m.outgoingRequestBadLocalAddressErrors.Init(a.OutgoingRequestBadLocalAddressErrors, b.OutgoingRequestBadLocalAddressErrors)
+ m.outgoingRequestsDropped.Init(a.OutgoingRequestsDropped, b.OutgoingRequestsDropped)
+ m.outgoingRequestsSent.Init(a.OutgoingRequestsSent, b.OutgoingRequestsSent)
+ m.repliesReceived.Init(a.RepliesReceived, b.RepliesReceived)
+ m.outgoingRepliesDropped.Init(a.OutgoingRepliesDropped, b.OutgoingRepliesDropped)
+ m.outgoingRepliesSent.Init(a.OutgoingRepliesSent, b.OutgoingRepliesSent)
+}
+
+// LINT.ThenChange(../../tcpip.go:ARPStats)
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
new file mode 100644
index 000000000..036fdf739
--- /dev/null
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -0,0 +1,93 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkInterface
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ nic := testInterface{nicID: 1}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
+ }
+
+ ep.Close()
+
+ proto.mu.Lock()
+ _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEndpointAfterClose {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // expected to be bound by a MultiCounterStat.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
index ca1247c1e..411bca25d 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/ip/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "ip",
- srcs = ["generic_multicast_protocol.go"],
+ srcs = [
+ "generic_multicast_protocol.go",
+ "stats.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index f2f0e069c..a81f5c8c3 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -174,10 +174,10 @@ type MulticastGroupProtocol interface {
//
// Returns false if the caller should queue the report to be sent later. Note,
// returning false does not mean that the receiver hit an error.
- SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+ SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
- SendLeave(groupAddress tcpip.Address) *tcpip.Error
+ SendLeave(groupAddress tcpip.Address) tcpip.Error
}
// GenericMulticastProtocolState is the per interface generic multicast protocol
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 85593f211..d5d5a449e 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
@@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go
new file mode 100644
index 000000000..898f8b356
--- /dev/null
+++ b/pkg/tcpip/network/ip/stats.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// LINT.IfChange(MultiCounterIPStats)
+
+// MultiCounterIPStats holds IP statistics, each counter may have several
+// versions.
+type MultiCounterIPStats struct {
+ // PacketsReceived is the total number of IP packets received from the link
+ // layer.
+ PacketsReceived tcpip.MultiCounterStat
+
+ // DisabledPacketsReceived is the total number of IP packets received from the
+ // link layer when the IP layer is disabled.
+ DisabledPacketsReceived tcpip.MultiCounterStat
+
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived tcpip.MultiCounterStat
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived tcpip.MultiCounterStat
+
+ // PacketsDelivered is the total number of incoming IP packets that are
+ // successfully delivered to the transport layer.
+ PacketsDelivered tcpip.MultiCounterStat
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent tcpip.MultiCounterStat
+
+ // OutgoingPacketErrors is the total number of IP packets which failed to
+ // write to a link-layer endpoint.
+ OutgoingPacketErrors tcpip.MultiCounterStat
+
+ // MalformedPacketsReceived is the total number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
+ MalformedPacketsReceived tcpip.MultiCounterStat
+
+ // MalformedFragmentsReceived is the total number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
+ MalformedFragmentsReceived tcpip.MultiCounterStat
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // Prerouting chain.
+ IPTablesPreroutingDropped tcpip.MultiCounterStat
+
+ // IPTablesInputDropped is the total number of IP packets dropped in the Input
+ // chain.
+ IPTablesInputDropped tcpip.MultiCounterStat
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in the
+ // Output chain.
+ IPTablesOutputDropped tcpip.MultiCounterStat
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived tcpip.MultiCounterStat
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived tcpip.MultiCounterStat
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived tcpip.MultiCounterStat
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
+ m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
+ m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
+ m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
+ m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered)
+ m.PacketsSent.Init(a.PacketsSent, b.PacketsSent)
+ m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors)
+ m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
+ m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
+ m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
+ m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived)
+ m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived)
+ m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+}
+
+// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 3005973d7..47cce79bb 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -18,6 +18,7 @@ import (
"strings"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -167,7 +168,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -189,7 +190,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
@@ -203,7 +204,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -219,7 +220,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -235,14 +236,14 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) {
t.Helper()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- e := channel.New(0, 1280, "")
+ e := channel.New(1, mtu, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -263,7 +264,7 @@ func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpo
func buildDummyStack(t *testing.T) *stack.Stack {
t.Helper()
- s, _ := buildDummyStackWithLinkEndpoint(t)
+ s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
return s
}
@@ -306,8 +307,8 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func TestSourceAddressValidation(t *testing.T) {
@@ -416,7 +417,7 @@ func TestSourceAddressValidation(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t)
+ s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
test.rxICMP(e, test.srcAddress)
var wantValid uint64
@@ -479,8 +480,9 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// Attempting to enable the endpoint while the NIC is disabled should
// fail.
nic.setEnabled(false)
- if err := ep.Enable(); err != tcpip.ErrNotPermitted {
- t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ err := ep.Enable()
+ if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
}
// ep should consider the NIC's enabled status when determining its own
// enabled status so we "enable" the NIC to read just the endpoint's
@@ -1122,7 +1124,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
}{
{
name: "IPv4",
@@ -1187,7 +1189,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 too small",
@@ -1205,7 +1207,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 minimum size",
@@ -1465,7 +1467,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
}
@@ -1490,7 +1492,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
})
- e := channel.New(1, 1280, "")
+ e := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -1506,10 +1508,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
defer r.Release()
- if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr),
- })); err != test.expectedErr {
- t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ {
+ err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr),
+ }))
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
+ }
}
if test.expectedErr != nil {
@@ -1526,3 +1531,246 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
}
}
+
+// Test that the included data in an ICMP error packet conforms to the
+// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
+func TestICMPInclusionSize(t *testing.T) {
+ const (
+ replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ targetSize4 = header.IPv4MinimumProcessableDatagramSize
+ targetSize6 = header.IPv6MinimumMTU
+ // A protocol number that will cause an error response.
+ reservedProtocol = 254
+ )
+
+ // IPv4 function to create a IP packet and send it to the stack.
+ // The packet should generate an error response. We can do that by using an
+ // unknown transport protocol (254).
+ rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(payload)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize)
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: reservedProtocol,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(buffer.View(payload))
+ // Take a copy before InjectInbound takes ownership of vv
+ // as vv may be changed during the call.
+ v := vv.ToView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ return v
+ }
+
+ // IPv6 function to create a packet and send it to the stack.
+ // The packet should be errant in a way that causes the stack to send an
+ // ICMP error response and have enough data to allow the testing of the
+ // inclusion of the errant packet. Use `unknown next header' to generate
+ // the error.
+ rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload)),
+ TransportProtocol: reservedProtocol,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(buffer.View(payload))
+ // Take a copy before InjectInbound takes ownership of vv
+ // as vv may be changed during the call.
+ v := vv.ToView()
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ return v
+ }
+
+ v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
+ // We already know the entire packet is the right size so we can use its
+ // length to calculate the right payload size to check.
+ expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
+ checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()),
+ checker.SrcAddr(localIPv4Addr),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
+ checker.ICMPv4Payload(payload[:expectedPayloadLength]),
+ ),
+ )
+ }
+
+ v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
+ // We already know the entire packet is the right size so we can use its
+ // length to calculate the right payload size to check.
+ expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
+ checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()),
+ checker.SrcAddr(localIPv6Addr),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6ParamProblem),
+ checker.ICMPv6Code(header.ICMPv6UnknownHeader),
+ checker.ICMPv6Payload(payload[:expectedPayloadLength]),
+ ),
+ )
+ }
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, buffer.View)
+ payloadLength int // Not including IP header.
+ linkMTU uint32 // Largest IP packet that the link can send as payload.
+ replyLength int // Total size of IP/ICMP packet expected back.
+ }{
+ {
+ name: "IPv4 exact match",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 - replyHeaderLength4,
+ linkMTU: targetSize4,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 larger MTU",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4,
+ linkMTU: targetSize4 + 1000,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 smaller MTU",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4,
+ linkMTU: targetSize4 - 50,
+ replyLength: targetSize4 - 50,
+ },
+ {
+ name: "IPv4 payload exceeds",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 + 10,
+ linkMTU: targetSize4,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 1 byte less",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 - replyHeaderLength4 - 1,
+ linkMTU: targetSize4,
+ replyLength: targetSize4 - 1,
+ },
+ {
+ name: "IPv4 No payload",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: 0,
+ linkMTU: targetSize4,
+ replyLength: replyHeaderLength4,
+ },
+ {
+ name: "IPv6 exact match",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6 - replyHeaderLength6,
+ linkMTU: targetSize6,
+ replyLength: targetSize6,
+ },
+ {
+ name: "IPv6 larger MTU",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6,
+ linkMTU: targetSize6 + 400,
+ replyLength: targetSize6,
+ },
+ // NB. No "smaller MTU" test here as less than 1280 is not permitted
+ // in IPv6.
+ {
+ name: "IPv6 payload exceeds",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6,
+ linkMTU: targetSize6,
+ replyLength: targetSize6,
+ },
+ {
+ name: "IPv6 1 byte less",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6 - replyHeaderLength6 - 1,
+ linkMTU: targetSize6,
+ replyLength: targetSize6 - 1,
+ },
+ {
+ name: "IPv6 no payload",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: 0,
+ linkMTU: targetSize6,
+ replyLength: replyHeaderLength6,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU)
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(test.payloadLength)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+ // Default routes for IPv4&6 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP error Reply.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+ v := test.injector(e, test.srcAddress, payload)
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ if got, want := pkt.Pkt.Size(), test.replyLength; got != want {
+ t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
+ }
+ test.checker(t, pkt.Pkt, v)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 32f53f217..330a7d170 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -8,6 +8,7 @@ go_library(
"icmp.go",
"igmp.go",
"ipv4.go",
+ "stats.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -49,3 +50,15 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "stats_test",
+ size = "small",
+ srcs = ["stats_test.go"],
+ library = ":ipv4",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 8e392f86c..3d93a2cd0 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
package ipv4
import (
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -62,21 +61,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4.PacketsReceived
+ received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
- received.Invalid.Increment()
+ received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
// we are the only handler, however other types do not cope well with
@@ -106,19 +104,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := e.processIPOptions(pkt, opts, op)
- if err != nil {
- switch {
- case
- errors.Is(err, header.ErrIPv4OptDuplicate),
- errors.Is(err, errIPv4RecordRouteOptInvalidLength),
- errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ tmp, optProblem := e.processIPOptions(pkt, opts, op)
+ if optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -128,11 +121,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
- received.Echo.Increment()
+ received.echo.Increment()
- sent := stats.ICMP.V4.PacketsSent
+ sent := e.stats.icmp.packetsSent
if !e.protocol.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return
}
@@ -213,18 +206,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.EchoReply.Increment()
+ sent.echoReply.Increment()
case header.ICMPv4EchoReply:
- received.EchoReply.Increment()
+ received.echoReply.Increment()
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
- received.DstUnreachable.Increment()
+ received.dstUnreachable.Increment()
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
@@ -243,31 +236,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
case header.ICMPv4SrcQuench:
- received.SrcQuench.Increment()
+ received.srcQuench.Increment()
case header.ICMPv4Redirect:
- received.Redirect.Increment()
+ received.redirect.Increment()
case header.ICMPv4TimeExceeded:
- received.TimeExceeded.Increment()
+ received.timeExceeded.Increment()
case header.ICMPv4ParamProblem:
- received.ParamProblem.Increment()
+ received.paramProblem.Increment()
case header.ICMPv4Timestamp:
- received.Timestamp.Increment()
+ received.timestamp.Increment()
case header.ICMPv4TimestampReply:
- received.TimestampReply.Increment()
+ received.timestampReply.Increment()
case header.ICMPv4InfoRequest:
- received.InfoRequest.Increment()
+ received.infoRequest.Increment()
case header.ICMPv4InfoReply:
- received.InfoReply.Increment()
+ received.infoReply.Increment()
default:
- received.Invalid.Increment()
+ received.invalid.Increment()
}
}
@@ -317,7 +310,7 @@ func (*icmpReasonParamProblem) isICMPReason() {}
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv4(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -379,9 +372,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4.PacketsSent
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[pkt.NICID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+
if !p.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return nil
}
@@ -435,13 +436,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
mtu := int(route.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
+ const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
+ if mtu > maxIPData {
+ mtu = maxIPData
}
- headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
+ available := mtu - header.ICMPv4MinimumSize
- if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize {
return nil
}
@@ -464,36 +465,36 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
payload.CapLength(payloadLen)
icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
+ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize,
Data: payload,
})
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter *tcpip.StatCounter
+ var counter tcpip.MultiCounterStat
switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonProtoUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonParamProblem:
icmpHdr.SetType(header.ICMPv4ParamProblem)
icmpHdr.SetCode(header.ICMPv4UnusedCode)
icmpHdr.SetPointer(reason.pointer)
- counter = sent.ParamProblem
+ counter = sent.paramProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
@@ -508,7 +509,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
},
icmpPkt,
); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
counter.Increment()
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index d9b5fe6ed..4cd0b3256 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
// saying we were the last host to report is cleared, this action MAY be
@@ -149,51 +149,49 @@ func (igmp *igmpState) init(ep *endpoint) {
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
- stats := igmp.ep.protocol.stack.Stats()
- received := stats.IGMP.PacketsReceived
+ received := igmp.ep.stats.igmp.packetsReceived
headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.IGMP(headerView)
- // Temporarily reset the checksum field to 0 in order to calculate the proper
- // checksum.
- wantChecksum := h.Checksum()
- h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- h.SetChecksum(wantChecksum)
-
- if gotChecksum != wantChecksum {
- received.ChecksumErrors.Increment()
+ // As per RFC 1071 section 1.3,
+ //
+ // To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF {
+ received.checksumErrors.Increment()
return
}
switch h.Type() {
case header.IGMPMembershipQuery:
- received.MembershipQuery.Increment()
+ received.membershipQuery.Increment()
if len(headerView) < header.IGMPQueryMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
case header.IGMPv1MembershipReport:
- received.V1MembershipReport.Increment()
+ received.v1MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPv2MembershipReport:
- received.V2MembershipReport.Increment()
+ received.v2MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPLeaveGroup:
- received.LeaveGroup.Increment()
+ received.leaveGroup.Increment()
// As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
// Report, are ignored in all states"
@@ -201,7 +199,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
// As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
// be silently ignored. New message types may be used by newer versions of
// IGMP, by multicast routing protocols, or other uses."
- received.Unrecognized.Increment()
+ received.unrecognized.Increment()
}
}
@@ -244,7 +242,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
// writePacket assembles and sends an IGMP packet.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -272,18 +270,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ sentStats := igmp.ep.stats.igmp.packetsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sentStats.Dropped.Increment()
+ sentStats.dropped.Increment()
return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sentStats.V1MembershipReport.Increment()
+ sentStats.v1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sentStats.V2MembershipReport.Increment()
+ sentStats.v2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sentStats.LeaveGroup.Increment()
+ sentStats.leaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
@@ -295,7 +293,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// messages.
//
// If the group already exists in the membership map, returns
-// tcpip.ErrDuplicateAddress.
+// *tcpip.ErrDuplicateAddress.
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
@@ -314,13 +312,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
// if required.
//
// Precondition: igmp.ep.mu must be locked.
-func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index bb25a76fe..e5c80699d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
package ipv4
import (
- "errors"
"fmt"
"math"
+ "reflect"
"sync/atomic"
"time"
@@ -73,6 +73,7 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
+ stats sharedStats
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -114,18 +115,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
e.mu.addressableEndpointState.Init(e)
e.mu.igmp.init(e)
e.mu.Unlock()
+
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+
+ stackStats := p.stack.Stats()
+ e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP)
+ e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4)
+ e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP)
+
+ p.mu.Lock()
+ p.mu.eps[nic.ID()] = e
+ p.mu.Unlock()
+
return e
}
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, nicID)
+}
+
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -193,7 +212,9 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
@@ -202,7 +223,9 @@ func (e *endpoint) disableLocked() {
e.mu.igmp.softLeaveAll()
// The address may have already been removed.
- if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
@@ -237,7 +260,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error {
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -245,12 +268,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
}
hdrLen += optLen
if hdrLen > header.IPv4MaximumHeaderSize {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := pkt.Size()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
@@ -275,7 +298,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -295,17 +318,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return err
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
- e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
+ e.stats.ip.IPTablesOutputDropped.Increment()
return nil
}
@@ -334,7 +357,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -349,35 +372,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ stats := e.stats.ip
+
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
- r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ stats.PacketsSent.IncrementBy(uint64(sent))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
- r.Stats().IP.PacketsSent.Increment()
+ stats.PacketsSent.Increment()
return nil
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
@@ -385,6 +410,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ stats := e.stats.ip
+
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return 0, err
@@ -392,7 +419,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
@@ -400,7 +427,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
@@ -413,21 +440,21 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
- r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
@@ -451,36 +478,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ stats.PacketsSent.IncrementBy(uint64(n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
}
n++
}
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
hdrLen := header.IPv4(h).HeaderLength()
if hdrLen < header.IPv4MinimumSize {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
h, ok = pkt.Data.PullUp(int(hdrLen))
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv4(h)
@@ -518,14 +545,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// wire format. We also want to check if the header's fields are valid before
// sending the packet.
if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
ttl := h.TTL()
if ttl == 0 {
@@ -545,7 +572,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -577,19 +604,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- stats.IP.PacketsReceived.Increment()
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
if !e.isEnabled() {
- stats.IP.DisabledPacketsReceived.Increment()
+ stats.DisabledPacketsReceived.Increment()
return
}
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesPreroutingDropped.Increment()
+ stats.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -601,11 +630,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
- stats := e.protocol.stack.Stats()
+ stats := e.stats
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -631,7 +660,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -643,7 +672,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// be one of its own IP addresses (but not a broadcast or
// multicast address).
if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
// Make sure the source address is not a subnet-local broadcast address.
@@ -651,7 +680,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
subnet := addressEndpoint.Subnet()
addressEndpoint.DecRef()
if subnet.IsBroadcast(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
}
@@ -664,7 +693,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
+ stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -674,9 +703,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesInputDropped.Increment()
+ stats.ip.IPTablesInputDropped.Increment()
return
}
@@ -684,8 +714,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -698,8 +728,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
@@ -720,8 +750,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt,
)
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
if !ready {
@@ -734,7 +764,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
- stats.IP.PacketsDelivered.Increment()
+ stats.ip.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
@@ -755,19 +785,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{})
- if err != nil {
- switch {
- case
- errors.Is(err, header.ErrIPv4OptDuplicate),
- errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
- errors.Is(err, errIPv4RecordRouteOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -800,10 +824,12 @@ func (e *endpoint) Close() {
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+
+ e.protocol.forgetEndpoint(e.nic.ID())
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -815,7 +841,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
@@ -872,7 +898,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -881,9 +907,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.igmp.joinGroup(addr)
@@ -891,7 +917,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -900,7 +926,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.igmp.leaveGroup(addr)
}
@@ -911,6 +937,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
return e.mu.igmp.isInGroup(addr)
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -918,6 +949,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
+ mu struct {
+ sync.RWMutex
+
+ // eps is keyed by NICID to allow protocol methods to retrieve an endpoint
+ // when handling a packet, by looking at which NIC handled the packet.
+ eps map[tcpip.NICID]*endpoint
+ }
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -960,24 +999,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1023,9 +1062,9 @@ func (p *protocol) SetForwarding(v bool) {
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
-func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv4MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
@@ -1033,7 +1072,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro
// The maximal internet header is 60 octets, and a typical internet header
// is 20 octets, allowing a margin for headers of higher level protocols.
if networkHeaderSize > header.IPv4MaximumHeaderSize {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU
@@ -1095,6 +1134,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
options: opts,
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ p.mu.eps = make(map[tcpip.NICID]*endpoint)
return p
}
}
@@ -1192,16 +1232,9 @@ func (*optionUsageEcho) actions() optionActions {
}
}
-var (
- errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
- errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
- errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
- errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
-)
-
// handleTimestamp does any required processing on a Timestamp option
// in place.
-func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) *header.IPv4OptParameterProblem {
flags := tsOpt.Flags()
var entrySize uint8
switch flags {
@@ -1212,7 +1245,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
header.IPv4OptionTimestampWithPredefinedIPFlag:
entrySize = header.IPv4OptionTimestampWithAddrSize
default:
- return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSOFLWAndFLGOffset,
+ NeedICMP: true,
+ }
}
pointer := tsOpt.Pointer()
@@ -1220,7 +1256,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
// Since the pointer is 1 based, and the header is 4 bytes long the
// pointer must point beyond the header therefore 4 or less is bad.
if pointer <= header.IPv4OptionTimestampHdrLength {
- return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSPointerOffset,
+ NeedICMP: true,
+ }
}
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
@@ -1254,14 +1293,17 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
// timestamp, but the overflow count is incremented by one.
if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
// By definition we have nothing to do.
- return 0, nil
+ return nil
}
if tsOpt.IncOverflow() != 0 {
- return 0, nil
+ return nil
}
// The overflow count is also full.
- return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSOFLWAndFLGOffset,
+ NeedICMP: true,
+ }
}
if nextSlot+entrySize > dataLength {
// The data area isn't full but there isn't room for a new entry.
@@ -1280,32 +1322,36 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
if dataLength%entrySize != 0 {
// The Data section size should be a multiple of the expected
// timestamp entry size.
- return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: false,
+ }
}
// If the size is OK, the pointer must be corrupted.
}
- return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSPointerOffset,
+ NeedICMP: true,
+ }
}
if usage.actions().timestamp == optionProcess {
tsOpt.UpdateTimestamp(localAddress, clock)
}
- return 0, nil
+ return nil
}
-var (
- errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
- errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
-)
-
// handleRecordRoute checks and processes a Record route option. It is much
// like the timestamp type 1 option, but without timestamps. The passed in
// address is stored in the option in the correct spot if possible.
-func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) *header.IPv4OptParameterProblem {
optlen := rrOpt.Size()
if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
- return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: true,
+ }
}
pointer := rrOpt.Pointer()
@@ -1315,7 +1361,10 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// Since the pointer is 1 based, and the header is 3 bytes long the
// pointer must point beyond the header therefore 3 or less is bad.
if pointer <= header.IPv4OptionRecordRouteHdrLength {
- return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptRRPointerOffset,
+ NeedICMP: true,
+ }
}
// RFC 791 page 21 says
@@ -1332,7 +1381,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
if pointer > optlen {
- return 0, nil
+ return nil
}
// The data area isn't full but there isn't room for a new entry.
@@ -1357,17 +1406,23 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// }
if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
// Length is bad, not on integral number of slots.
- return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: true,
+ }
}
// If not length, the fault must be with the pointer.
}
- return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptRRPointerOffset,
+ NeedICMP: true,
+ }
}
if usage.actions().recordRoute == optionVerify {
- return 0, nil
+ return nil
}
rrOpt.StoreAddress(localAddress)
- return 0, nil
+ return nil
}
// processIPOptions parses the IPv4 options and produces a new set of options
@@ -1378,8 +1433,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// - The location of an error if there was one (or 0 if no error)
// - If there is an error, information as to what it was was.
// - The replacement option set.
-func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
- stats := e.protocol.stack.Stats()
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, *header.IPv4OptParameterProblem) {
+ stats := e.stats.ip
opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
@@ -1392,21 +1447,23 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
// This will need tweaking when we start really forwarding packets
// as we may need to get two addresses, for rx and tx interfaces.
// We will also have to take usage into account.
- prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
+ prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
localAddress := prefixedAddress.Address
- if err != nil {
+ if !ok {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
- return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ return nil, &header.IPv4OptParameterProblem{
+ NeedICMP: false,
+ }
}
localAddress = dstAddr
}
for {
- option, done, err := optIter.Next()
- if done || err != nil {
- return optIter.ErrCursor, optIter.Finalize(), err
+ option, done, optProblem := optIter.Next()
+ if done || optProblem != nil {
+ return optIter.Finalize(), optProblem
}
optType := option.Type()
if optType == header.IPv4OptionNOPType {
@@ -1415,44 +1472,47 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
}
if optType == header.IPv4OptionListEndType {
optIter.PushNOPOrEnd(optType)
- return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ return optIter.Finalize(), nil
}
// check for repeating options (multiple NOPs are OK)
if seenOptions[optType] {
- return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ return nil, &header.IPv4OptParameterProblem{
+ Pointer: optIter.ErrCursor,
+ NeedICMP: true,
+ }
}
seenOptions[optType] = true
optLen := int(option.Size())
switch option := option.(type) {
case *header.IPv4OptionTimestamp:
- stats.IP.OptionTSReceived.Increment()
+ stats.OptionTSReceived.Increment()
if usage.actions().timestamp != optionRemove {
clock := e.protocol.stack.Clock()
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
- offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
- if err != nil {
- return optIter.ErrCursor + offset, nil, err
+ if optProblem := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage); optProblem != nil {
+ optProblem.Pointer += optIter.ErrCursor
+ return nil, optProblem
}
optIter.ConsumeBuffer(optLen)
}
case *header.IPv4OptionRecordRoute:
- stats.IP.OptionRRReceived.Increment()
+ stats.OptionRRReceived.Increment()
if usage.actions().recordRoute != optionRemove {
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
- offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
- if err != nil {
- return optIter.ErrCursor + offset, nil, err
+ if optProblem := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage); optProblem != nil {
+ optProblem.Pointer += optIter.ErrCursor
+ return nil, optProblem
}
optIter.ConsumeBuffer(optLen)
}
default:
- stats.IP.OptionUnknownReceived.Increment()
+ stats.OptionUnknownReceived.Increment()
if usage.actions().unknown == optionPass {
newBuffer := optIter.RemainingBuffer()[:optLen]
// Arguments already heavily checked.. ignore result.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index a9e137c24..ed5899f0b 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -78,8 +78,11 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Cannot connect using a broadcast address as the source.
- if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ {
+ err := ep.Connect(randomAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{})
+ }
}
// However, we can bind to a broadcast address to listen.
@@ -270,6 +273,11 @@ func TestIPv4Sanity(t *testing.T) {
nicID = 1
randomSequence = 123
randomIdent = 42
+ // In some cases Linux sets the error pointer to the start of the option
+ // (offset 0) instead of the actual wrong value, which is the length byte
+ // (offset 1). For compatibility we must do the same. Use this constant
+ // to indicate where this happens.
+ pointerOffsetForInvalidLength = 0
)
var (
ipv4Addr = tcpip.AddressWithPrefix{
@@ -439,6 +447,21 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{},
},
{
+ name: "bad option - no length",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 1, 1, 1, 68,
+ // ^-start of timestamp.. but no length..
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ },
+ {
name: "bad option - length 0",
maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
@@ -448,7 +471,27 @@ func TestIPv4Sanity(t *testing.T) {
// ^
1, 2, 3, 4,
},
- shouldFail: true,
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
+ },
+ {
+ name: "bad option - length 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 1, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
name: "bad option - length big",
@@ -462,7 +505,11 @@ func TestIPv4Sanity(t *testing.T) {
// space is not possible. (Second byte)
1, 2, 3, 4,
},
- shouldFail: true,
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
// This tests for some linux compatible behaviour.
@@ -484,7 +531,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
{
name: "multiple type 0 with room",
@@ -589,7 +636,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
{
name: "valid timestamp pointer",
@@ -624,7 +671,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
// End of option list with illegal option after it, which should be ignored.
{
@@ -636,24 +683,31 @@ func TestIPv4Sanity(t *testing.T) {
68, 12, 13, 0x11,
192, 168, 1, 12,
1, 2, 3, 4,
- 0, 10, 3, 99,
+ 0, 10, 3, 99, // EOL followed by junk
},
replyOptions: header.IPv4Options{
68, 12, 13, 0x21,
192, 168, 1, 12,
1, 2, 3, 4,
- 0, 0, 0, 0, // 3 bytes unknown option
- }, // ^ End of options hides following bytes.
+ 0, // End of Options hides following bytes.
+ 0, 0, 0, // 3 bytes unknown option removed.
+ },
},
{
- // Timestamp with a size too small.
+ // Timestamp with a size much too small.
name: "timestamp truncated",
maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
- options: header.IPv4Options{68, 1, 0, 0},
- // ^ Smallest possible is 8.
- shouldFail: true,
+ options: header.IPv4Options{
+ 68, 1, 0, 0,
+ // ^ Smallest possible is 8. Linux points at the 68.
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
name: "single record route with room",
@@ -751,7 +805,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
// Pointer must be 4 or more as it must point past the 3 byte header
@@ -769,7 +823,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
// Pointer must be 4 or more as it must point past the 3 byte header
@@ -808,8 +862,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
- replyOptions: header.IPv4Options{},
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
name: "duplicate record route",
@@ -828,7 +881,6 @@ func TestIPv4Sanity(t *testing.T) {
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
paramProblemPointer: header.IPv4MinimumSize + 7,
- replyOptions: header.IPv4Options{},
},
}
@@ -884,7 +936,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
-
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: test.transportProtocol,
@@ -1328,8 +1379,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -1338,8 +1389,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -1348,8 +1399,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -1358,8 +1409,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag MTU smaller than header",
@@ -1368,8 +1419,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize: 500,
allowPackets: 0,
outgoingErrors: 4,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than IPv4 minimum MTU",
@@ -1379,7 +1430,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -1393,8 +1444,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -2427,8 +2478,9 @@ func TestReceiveFragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2506,11 +2558,11 @@ func TestWriteStats(t *testing.T) {
// Parameterize the tests to run with both WritePacket and WritePackets.
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2522,7 +2574,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2532,7 +2584,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -2608,7 +2660,7 @@ func (*limitedMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
if lm.limit == 0 {
return true, false
}
diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go
new file mode 100644
index 000000000..bee72c649
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to the IPv4 protocol family.
+type Stats struct {
+ // IP holds IPv4 statistics.
+ IP tcpip.IPStats
+
+ // IGMP holds IGMP statistics.
+ IGMP tcpip.IGMPStats
+
+ // ICMP holds ICMPv4 statistics.
+ ICMP tcpip.ICMPv4Stats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+// IPStats implements stack.IPNetworkEndointStats
+func (s *Stats) IPStats() *tcpip.IPStats {
+ return &s.IP
+}
+
+type sharedStats struct {
+ localStats Stats
+ ip ip.MultiCounterIPStats
+ icmp multiCounterICMPv4Stats
+ igmp multiCounterIGMPStats
+}
+
+// LINT.IfChange(multiCounterICMPv4PacketStats)
+
+type multiCounterICMPv4PacketStats struct {
+ echo tcpip.MultiCounterStat
+ echoReply tcpip.MultiCounterStat
+ dstUnreachable tcpip.MultiCounterStat
+ srcQuench tcpip.MultiCounterStat
+ redirect tcpip.MultiCounterStat
+ timeExceeded tcpip.MultiCounterStat
+ paramProblem tcpip.MultiCounterStat
+ timestamp tcpip.MultiCounterStat
+ timestampReply tcpip.MultiCounterStat
+ infoRequest tcpip.MultiCounterStat
+ infoReply tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) {
+ m.echo.Init(a.Echo, b.Echo)
+ m.echoReply.Init(a.EchoReply, b.EchoReply)
+ m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
+ m.srcQuench.Init(a.SrcQuench, b.SrcQuench)
+ m.redirect.Init(a.Redirect, b.Redirect)
+ m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
+ m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
+ m.timestamp.Init(a.Timestamp, b.Timestamp)
+ m.timestampReply.Init(a.TimestampReply, b.TimestampReply)
+ m.infoRequest.Init(a.InfoRequest, b.InfoRequest)
+ m.infoReply.Init(a.InfoReply, b.InfoReply)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats)
+
+// LINT.IfChange(multiCounterICMPv4SentPacketStats)
+
+type multiCounterICMPv4SentPacketStats struct {
+ multiCounterICMPv4PacketStats
+ dropped tcpip.MultiCounterStat
+ rateLimited tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+ m.rateLimited.Init(a.RateLimited, b.RateLimited)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats)
+
+type multiCounterICMPv4ReceivedPacketStats struct {
+ multiCounterICMPv4PacketStats
+ invalid tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4Stats)
+
+type multiCounterICMPv4Stats struct {
+ packetsSent multiCounterICMPv4SentPacketStats
+ packetsReceived multiCounterICMPv4ReceivedPacketStats
+}
+
+func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4Stats)
+
+// LINT.IfChange(multiCounterIGMPPacketStats)
+
+type multiCounterIGMPPacketStats struct {
+ membershipQuery tcpip.MultiCounterStat
+ v1MembershipReport tcpip.MultiCounterStat
+ v2MembershipReport tcpip.MultiCounterStat
+ leaveGroup tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) {
+ m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery)
+ m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport)
+ m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport)
+ m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPPacketStats)
+
+// LINT.IfChange(multiCounterIGMPSentPacketStats)
+
+type multiCounterIGMPSentPacketStats struct {
+ multiCounterIGMPPacketStats
+ dropped tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats)
+
+// LINT.IfChange(multiCounterIGMPReceivedPacketStats)
+
+type multiCounterIGMPReceivedPacketStats struct {
+ multiCounterIGMPPacketStats
+ invalid tcpip.MultiCounterStat
+ checksumErrors tcpip.MultiCounterStat
+ unrecognized tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+ m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors)
+ m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats)
+
+// LINT.IfChange(multiCounterIGMPStats)
+
+type multiCounterIGMPStats struct {
+ packetsSent multiCounterIGMPSentPacketStats
+ packetsReceived multiCounterIGMPReceivedPacketStats
+}
+
+func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPStats)
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
new file mode 100644
index 000000000..b28e7dcde
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkInterface
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ nic := testInterface{nicID: 1}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
+ }
+
+ ep.Close()
+
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // expected to be bound by a MultiCounterStat.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index afa45aefe..0c5f8d683 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -10,6 +10,7 @@ go_library(
"ipv6.go",
"mld.go",
"ndp.go",
+ "stats.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6ee162713..7298bd061 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -125,15 +125,14 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
- stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6.PacketsSent
- received := stats.V6.PacketsReceived
+ sent := e.stats.icmp.packetsSent
+ received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.ICMPv6(v)
@@ -147,7 +146,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
payload := pkt.Data.Clone(nil)
payload.TrimFront(len(h))
if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -165,10 +164,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
- received.PacketTooBig.Increment()
+ received.packetTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
@@ -179,10 +178,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
- received.DstUnreachable.Increment()
+ received.dstUnreachable.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
@@ -194,9 +193,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6NeighborSolicit:
- received.NeighborSolicit.Increment()
+ received.neighborSolicit.Increment()
if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -210,7 +209,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
// address.
if header.IsV6MulticastAddress(targetAddr) {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -238,7 +237,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
}
@@ -263,13 +264,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
if err != nil {
// Options are not valid as per the wire format, silently drop the
// packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok = getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
}
@@ -282,16 +283,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
unspecifiedSource := srcAddr == header.IPv6Any
if len(sourceLinkAddr) == 0 {
if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
} else if unspecifiedSource {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
} else if e.nud != nil {
e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -301,7 +302,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// - If the IP source address is the unspecified address, the IP
// destination address is a solicited-node multicast address.
if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -379,15 +380,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.NeighborAdvert.Increment()
+ sent.neighborAdvert.Increment()
case header.ICMPv6NeighborAdvert:
- received.NeighborAdvert.Increment()
+ received.neighborAdvert.Increment()
if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -414,16 +415,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ return
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
- return
}
it, err := na.Options().Iter(false /* check */)
if err != nil {
// If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -438,7 +441,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// DAD.
targetLinkAddr, ok := getTargetLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
+ return
+ }
+
+ // As per RFC 4861 section 7.1.2:
+ // A node MUST silently discard any received Neighbor Advertisement
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP Destination Address is a multicast address the
+ // Solicited flag is zero.
+ if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() {
+ received.invalid.Increment()
return
}
@@ -446,7 +460,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// address cache with the link address for the target of the message.
if e.nud == nil {
if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr)
}
return
}
@@ -458,10 +472,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
})
case header.ICMPv6EchoRequest:
- received.EchoRequest.Increment()
+ received.echoRequest.Increment()
icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -493,27 +507,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
TTL: r.DefaultTTL(),
TOS: stack.DefaultTOS,
}, replyPkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.EchoReply.Increment()
+ sent.echoReply.Increment()
case header.ICMPv6EchoReply:
- received.EchoReply.Increment()
+ received.echoReply.Increment()
if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
- received.TimeExceeded.Increment()
+ received.timeExceeded.Increment()
case header.ICMPv6ParamProblem:
- received.ParamProblem.Increment()
+ received.paramProblem.Increment()
case header.ICMPv6RouterSolicit:
- received.RouterSolicit.Increment()
+ received.routerSolicit.Increment()
//
// Validate the RS as per RFC 4861 section 6.1.1.
@@ -521,7 +535,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the NDP payload of sufficient size to hold a Router Solictation?
if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -530,7 +544,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the networking stack operating as a router?
if !stack.Forwarding(ProtocolNumber) {
// ... No, silently drop the packet.
- received.RouterOnlyPacketsDroppedByHost.Increment()
+ received.routerOnlyPacketsDroppedByHost.Increment()
return
}
@@ -540,13 +554,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
it, err := rs.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok := getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -557,7 +571,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// NOT be included when the source IP address is the unspecified address.
// Otherwise, it SHOULD be included on link layers that have addresses.
if srcAddr == header.IPv6Any {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -569,7 +583,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ received.routerAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
@@ -577,7 +591,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the NDP payload of sufficient size to hold a Router Advertisement?
if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -586,7 +600,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -596,13 +610,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
it, err := ra.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok := getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -638,26 +652,26 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// link-layer address be modified due to receiving one of the above
// messages, the state SHOULD also be set to STALE to provide prompt
// verification that the path to the new link-layer address is working."
- received.RedirectMsg.Increment()
+ received.redirectMsg.Increment()
if !isNDPValid() {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
- received.MulticastListenerQuery.Increment()
+ received.multicastListenerQuery.Increment()
case header.ICMPv6MulticastListenerReport:
- received.MulticastListenerReport.Increment()
+ received.multicastListenerReport.Increment()
case header.ICMPv6MulticastListenerDone:
- received.MulticastListenerDone.Increment()
+ received.multicastListenerDone.Increment()
default:
panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -676,7 +690,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
default:
- received.Unrecognized.Increment()
+ received.unrecognized.Increment()
}
}
@@ -688,26 +702,39 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
+ nicID := nic.ID()
+
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[nicID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
remoteAddr := targetAddr
if len(remoteLinkAddr) == 0 {
remoteAddr = header.SolicitedNodeAddr(targetAddr)
remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ if len(localAddr) == 0 {
+ addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return &tcpip.ErrNetworkUnreachable{}
+ }
+
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 {
+ return &tcpip.ErrBadLocalAddress{}
}
- defer r.Release()
- r.ResolveWith(remoteLinkAddr)
optsSerializer := header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
@@ -715,18 +742,23 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
- packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, pkt); err != nil {
- stat.Dropped.Increment()
+ }, header.IPv6ExtHdrSerializer{}); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
+
+ stat := netEP.stats.icmp.packetsSent
+
+ if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ stat.dropped.Increment()
return err
}
- stat.NeighborSolicit.Increment()
+ stat.neighborSolicit.Increment()
return nil
}
@@ -796,7 +828,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv6(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -863,10 +895,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- stats := p.stack.Stats().ICMP
- sent := stats.V6.PacketsSent
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[pkt.NICID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+
if !p.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return nil
}
@@ -897,11 +936,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
mtu := int(route.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
+ const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize
+ if mtu > maxIPv6Data {
+ mtu = maxIPv6Data
}
- headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
- available := int(mtu) - headerLen
+ available := mtu - header.ICMPv6ErrorHeaderSize
if available < header.IPv6MinimumSize {
return nil
}
@@ -915,31 +954,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
+ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize,
Data: payload,
})
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter *tcpip.StatCounter
+ var counter tcpip.MultiCounterStat
switch reason := reason.(type) {
case *icmpReasonParameterProblem:
icmpHdr.SetType(header.ICMPv6ParamProblem)
icmpHdr.SetCode(reason.code)
icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.ParamProblem
+ counter = sent.paramProblem
case *icmpReasonPortUnreachable:
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonHopLimitExceeded:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
@@ -953,7 +992,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
},
newPkt,
); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
counter.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index b1e6a70a2..db1c2e663 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "bytes"
"context"
"net"
"reflect"
@@ -22,6 +23,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -77,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
return nil
}
@@ -91,16 +93,11 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
-type stubLinkAddressCache struct {
- stack.LinkAddressCache
-}
+var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil)
-func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
- return 0
-}
+type stubLinkAddressCache struct{}
-func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
-}
+func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {}
type stubNUDHandler struct {
probeCount int
@@ -148,15 +145,15 @@ func (*testInterface) Promiscuous() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var r stack.RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
@@ -643,7 +640,6 @@ func TestLinkResolution(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
// doesn't provoke NDP discovery.
@@ -653,8 +649,12 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
- if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
- t.Fatalf("ep.Write(_): %s", err)
+ {
+ var r bytes.Reader
+ r.Reset(hdr.View())
+ if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
+ t.Fatalf("ep.Write(_): %s", err)
+ }
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -1283,7 +1283,7 @@ func TestLinkAddressRequest(t *testing.T) {
localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
expectedRemoteAddr tcpip.Address
expectedRemoteLinkAddr tcpip.LinkAddress
}{
@@ -1321,79 +1321,80 @@ func TestLinkAddressRequest(t *testing.T) {
name: "Unicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Multicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Unicast with no local address available",
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
{
name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
}
for _, test := range tests {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- p := s.NetworkProtocolInstance(ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
- }
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
- linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
}
- }
- // We pass a test network interface to LinkAddressRequest with the same NIC
- // ID and link endpoint used by the NIC we created earlier so that we can
- // mock a link address request and observe the packets sent to the link
- // endpoint even though the stack uses the real NIC.
- if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
- }
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID})
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
+ }
- if test.expectedErr != nil {
- return
- }
+ if test.expectedErr != nil {
+ return
+ }
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
- }
- if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
- }
- if pkt.Route.LocalAddress != lladdr1 {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
- }
- checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
- checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedRemoteAddr),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
- ))
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ var want stack.RouteInfo
+ want.NetProto = ProtocolNumber
+ want.RemoteLinkAddress = test.expectedRemoteLinkAddr
+ if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" {
+ t.Errorf("route info mismatch (-want +got):\n%s", diff)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedRemoteAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index ae4a8f508..94043ed4e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash/fnv"
"math"
+ "reflect"
"sort"
"sync/atomic"
"time"
@@ -177,6 +178,7 @@ type endpoint struct {
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
+ stats sharedStats
// enabled is set to 1 when the endpoint is enabled and 0 when it is
// disabled.
@@ -305,17 +307,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
@@ -367,14 +369,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) {
}
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -416,7 +418,7 @@ func (e *endpoint) Enable() *tcpip.Error {
//
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
- var err *tcpip.Error
+ var err tcpip.Error
e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
@@ -497,7 +499,9 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
@@ -553,11 +557,11 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error {
+func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error {
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
@@ -583,7 +587,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta
// fragments left to be processed. The IP header must already be present in the
// original packet. The transport header protocol number is required to avoid
// parsing the IPv6 extension headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
networkHeader := header.IPv6(pkt.NetworkHeader().View())
// TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
@@ -596,13 +600,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// of 8 as per RFC 8200 section 4.5:
// Each complete fragment, except possibly the last ("rightmost") one, is
// an integer multiple of 8 octets long.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
@@ -622,17 +626,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+ if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
return err
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
- e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
+ e.stats.ip.IPTablesOutputDropped.Increment()
return nil
}
@@ -660,7 +664,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -675,36 +679,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
- r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ stats.PacketsSent.IncrementBy(uint64(sent))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
- r.Stats().IP.PacketsSent.Increment()
+ stats.PacketsSent.Increment()
return nil
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
@@ -712,28 +717,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ stats := e.stats.ip
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
+ if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
return 0, err
}
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
if packetMustBeFragmented(pb, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
return nil
}); err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
// Remove the packet that was just fragmented and process the rest.
@@ -743,19 +749,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
- r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
@@ -779,8 +785,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
+ stats.PacketsSent.IncrementBy(uint64(n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -788,17 +794,17 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
n++
}
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv6(h)
@@ -823,14 +829,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// sending the packet.
proto, _, _, _, ok := parse.IPv6(pkt)
if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
hopLimit := h.HopLimit()
if hopLimit <= 1 {
@@ -852,7 +858,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -882,19 +888,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- stats.IP.PacketsReceived.Increment()
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
if !e.isEnabled() {
- stats.IP.DisabledPacketsReceived.Increment()
+ stats.DisabledPacketsReceived.Increment()
return
}
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesPreroutingDropped.Increment()
+ stats.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -906,11 +914,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
- stats := e.protocol.stack.Stats()
+ stats := e.stats.ip
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
srcAddr := h.SourceAddress()
@@ -920,7 +928,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Multicast addresses must not be used as source addresses in IPv6
// packets or appear in any Routing header.
if header.IsV6MulticastAddress(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.InvalidSourceAddressesReceived.Increment()
return
}
@@ -930,7 +938,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
+ stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -950,9 +958,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesInputDropped.Increment()
+ stats.IPTablesInputDropped.Increment()
return
}
@@ -962,7 +971,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -986,7 +995,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -1075,8 +1084,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
it, done, err := it.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -1103,8 +1112,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
switch lastHdr.(type) {
case header.IPv6RawPayloadHeader:
default:
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
}
@@ -1112,8 +1121,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
@@ -1126,8 +1135,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// of the fragment, pointing to the Payload Length field of the
// fragment packet.
if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
@@ -1147,8 +1156,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
@@ -1173,8 +1182,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt,
)
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
@@ -1194,7 +1203,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -1244,12 +1253,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
- stats.IP.PacketsDelivered.Increment()
+ stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
e.handleICMP(pkt, hasFragmentHeader)
} else {
- stats.IP.PacketsDelivered.Increment()
+ stats.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -1314,7 +1323,7 @@ func (e *endpoint) Close() {
e.mu.addressableEndpointState.Cleanup()
e.mu.Unlock()
- e.protocol.forgetEndpoint(e)
+ e.protocol.forgetEndpoint(e.nic.ID())
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -1323,7 +1332,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
@@ -1338,7 +1347,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err != nil {
return nil, err
@@ -1367,13 +1376,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
return e.removePermanentEndpointLocked(addressEndpoint, true)
@@ -1383,7 +1392,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
unicast := header.IsV6UnicastAddress(addr.Address)
if unicast {
@@ -1408,12 +1417,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
+ err := e.leaveGroupLocked(snmc)
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); ok {
+ err = nil
}
-
- return nil
+ return err
}
// hasPermanentAddressLocked returns true if the endpoint has a permanent
@@ -1623,7 +1632,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -1632,9 +1641,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.mld.joinGroup(addr)
@@ -1642,7 +1651,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -1651,7 +1660,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.mld.leaveGroup(addr)
}
@@ -1662,6 +1671,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
return e.mu.mld.isInGroup(addr)
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1673,7 +1687,9 @@ type protocol struct {
mu struct {
sync.RWMutex
- eps map[*endpoint]struct{}
+ // eps is keyed by NICID to allow protocol methods to retrieve an endpoint
+ // when handling a packet, by looking at which NIC handled the packet.
+ eps map[tcpip.NICID]*endpoint
}
ids []uint32
@@ -1730,37 +1746,42 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e.mu.mld.init(e)
e.mu.Unlock()
+ stackStats := p.stack.Stats()
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+ e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP)
+ e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V6)
+
p.mu.Lock()
defer p.mu.Unlock()
- p.mu.eps[e] = struct{}{}
+ p.mu.eps[nic.ID()] = e
return e
}
-func (p *protocol) forgetEndpoint(e *endpoint) {
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
- delete(p.mu.eps, e)
+ delete(p.mu.eps, nicID)
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1814,7 +1835,7 @@ func (p *protocol) SetForwarding(v bool) {
return
}
- for ep := range p.mu.eps {
+ for _, ep := range p.mu.eps {
ep.transitionForwarding(v)
}
}
@@ -1823,9 +1844,9 @@ func (p *protocol) SetForwarding(v bool) {
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
// which includes the length of the extension headers.
-func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv6MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 7112 section 5, we should discard packets if their IPv6 header
@@ -1836,7 +1857,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro
// bytes ensures that the header chain length does not exceed the IPv6
// minimum MTU.
if networkHeadersLen > header.IPv6MinimumMTU {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU - uint32(networkHeadersLen)
@@ -1906,7 +1927,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
hashIV: hashIV,
}
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
- p.mu.eps = make(map[*endpoint]struct{})
+ p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
return p
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index b65c9d060..8248052a3 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -21,6 +21,7 @@ import (
"io/ioutil"
"math"
"net"
+ "reflect"
"testing"
"github.com/google/go-cmp/cmp"
@@ -370,12 +371,10 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
}
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok {
+ t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber)
+ } else if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr)
}
})
}
@@ -997,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}
// Should not have any more UDP packets.
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -1989,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2473,11 +2474,11 @@ func TestWriteStats(t *testing.T) {
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2489,7 +2490,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2499,7 +2500,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2574,7 +2575,7 @@ func (*limitedMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
if lm.limit == 0 {
return true, false
}
@@ -2582,30 +2583,44 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
return false, false
}
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
func TestClearEndpointFromProtocolOnClose(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
- {
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[ep]
- proto.mu.Unlock()
- if !hasEP {
- t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
- }
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
}
ep.Close()
- {
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[ep]
- proto.mu.Unlock()
- if hasEP {
- t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
- }
+ proto.mu.Lock()
+ _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEndpointAfterClose {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
}
}
@@ -2819,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -2829,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -2839,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -2849,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than transport header",
@@ -2860,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrMessageTooLong,
+ wantError: &tcpip.ErrMessageTooLong{},
},
{
description: "Error when MTU is smaller than IPv6 minimum MTU",
@@ -2870,7 +2885,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -2884,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -3053,3 +3068,22 @@ func TestForwarding(t *testing.T) {
})
}
}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // supposed to be bound.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index ec54d88cc..2cc0dfebd 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
_, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
@@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
-// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+// If the group is already joined, returns *tcpip.ErrDuplicateAddress.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
@@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
// required.
//
// Precondition: mld.ep.mu must be locked.
-func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of MLD, but remains
@@ -166,14 +166,14 @@ func (mld *mldState) sendQueuedReports() {
// writePacket assembles and sends an MLD packet.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
- sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- var mldStat *tcpip.StatCounter
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) {
+ sentStats := mld.ep.stats.icmp.packetsSent
+ var mldStat tcpip.MultiCounterStat
switch mldType {
case header.ICMPv6MulticastListenerReport:
- mldStat = sentStats.MulticastListenerReport
+ mldStat = sentStats.multicastListenerReport
case header.ICMPv6MulticastListenerDone:
- mldStat = sentStats.MulticastListenerDone
+ mldStat = sentStats.multicastListenerDone
default:
panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
}
@@ -249,14 +249,14 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
Data: buffer.View(icmp).ToVectorisedView(),
})
- if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
}, extensionHeaders); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sentStats.Dropped.Increment()
+ sentStats.dropped.Increment()
return false, err
}
mldStat.Increment()
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 1d8fee50b..d7dde1767 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -241,7 +241,7 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error)
// OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
@@ -614,10 +614,10 @@ type slaacPrefixState struct {
// tentative.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
- return tcpip.ErrAddressFamilyNotSupported
+ return &tcpip.ErrAddressFamilyNotSupported{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
@@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
dadDone := remaining == 0
- var err *tcpip.Error
+ var err tcpip.Error
if !dadDone {
err = ndp.sendDADPacket(addr, addressEndpoint)
}
@@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -731,8 +731,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ sent := ndp.ep.stats.icmp.packetsSent
+ if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
@@ -740,10 +740,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
}
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
- sent.NeighborSolicit.Increment()
+ sent.neighborSolicit.Increment()
+
return nil
}
@@ -1855,20 +1856,20 @@ func (ndp *ndpState) startSolicitingRouters() {
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ sent := ndp.ep.stats.icmp.packetsSent
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
- sent.RouterSolicit.Increment()
+ sent.routerSolicit.Increment()
remaining--
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index b1a5a5510..1d38b8b05 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -90,7 +90,7 @@ type testNDPDispatcher struct {
addr tcpip.Address
}
-func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
}
func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
@@ -162,10 +162,15 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
}
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+type linkResolutionResult struct {
+ linkAddr tcpip.LinkAddress
+ ok bool
+}
+
+// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
-func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -231,45 +236,40 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ wantInvalid := uint64(0)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
}
} else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
}
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
}
-// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NS message with the Source Link Layer Address
// option results in a new entry in the link address cache for the sender of
// the message.
-func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -382,7 +382,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
}
}
-func TestNeighorSolicitationResponse(t *testing.T) {
+func TestNeighborSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
remoteAddr := lladdr1
@@ -640,18 +640,12 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatal("expected an NDP NS response")
}
- if p.Route.LocalAddress != nicAddr {
- t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr)
- }
- if p.Route.LocalLinkAddress != nicLinkAddr {
- t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
- }
respNSDst := header.SolicitedNodeAddr(test.nsSrc)
- if p.Route.RemoteAddress != respNSDst {
- t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
- }
- if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ var want stack.RouteInfo
+ want.NetProto = ProtocolNumber
+ want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
+ if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
+ t.Errorf("route info mismatch (-want +got):\n%s", diff)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -727,10 +721,10 @@ func TestNeighorSolicitationResponse(t *testing.T) {
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
// valid NDP NA message with the Target Link Layer Address option results in a
// new entry in the link address cache for the target of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -803,45 +797,40 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ wantInvalid := uint64(0)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
}
} else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
}
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NA message with the Target Link Layer Address
// option does not result in a new entry in the neighbor cache for the target
// of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -1183,6 +1172,118 @@ func TestNDPValidation(t *testing.T) {
}
+// TestNeighborAdvertisementValidation tests that the NIC validates received
+// Neighbor Advertisements.
+//
+// In particular, if the IP Destination Address is a multicast address, and the
+// Solicited flag is not zero, the Neighbor Advertisement is invalid and should
+// be discarded.
+func TestNeighborAdvertisementValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ ipDstAddr tcpip.Address
+ solicitedFlag bool
+ valid bool
+ }{
+ {
+ name: "Multicast IP destination address with Solicited flag set",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: true,
+ valid: false,
+ },
+ {
+ name: "Multicast IP destination address with Solicited flag unset",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: false,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag set",
+ ipDstAddr: lladdr0,
+ solicitedFlag: true,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag unset",
+ ipDstAddr: lladdr0,
+ solicitedFlag: false,
+ valid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
+ na.SetTargetAddress(lladdr1)
+ na.SetSolicitedFlag(test.solicitedFlag)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: test.ipDstAddr,
+ })
+
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ rxNA := stats.NeighborAdvert
+
+ if got := rxNA.Value(); got != 0 {
+ t.Fatalf("got rxNA = %d, want = 0", got)
+ }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ if got := rxNA.Value(); got != 1 {
+ t.Fatalf("got rxNA = %d, want = 1", got)
+ }
+ var wantInvalid uint64 = 1
+ if test.valid {
+ wantInvalid = 0
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Fatalf("got invalid = %d, want = %d", got, wantInvalid)
+ }
+ // As per RFC 4861 section 7.2.5:
+ // When a valid Neighbor Advertisement is received ...
+ // If no entry exists, the advertisement SHOULD be silently discarded.
+ // There is no need to create an entry if none exists, since the
+ // recipient has apparently not initiated any communication with the
+ // target.
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+ })
+ }
+}
+
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
new file mode 100644
index 000000000..0839be3cd
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to the IPv6 protocol family.
+type Stats struct {
+ // IP holds IPv6 statistics.
+ IP tcpip.IPStats
+
+ // ICMP holds ICMPv6 statistics.
+ ICMP tcpip.ICMPv6Stats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+// IPStats implements stack.IPNetworkEndointStats
+func (s *Stats) IPStats() *tcpip.IPStats {
+ return &s.IP
+}
+
+type sharedStats struct {
+ localStats Stats
+ ip ip.MultiCounterIPStats
+ icmp multiCounterICMPv6Stats
+}
+
+// LINT.IfChange(multiCounterICMPv6PacketStats)
+
+type multiCounterICMPv6PacketStats struct {
+ echoRequest tcpip.MultiCounterStat
+ echoReply tcpip.MultiCounterStat
+ dstUnreachable tcpip.MultiCounterStat
+ packetTooBig tcpip.MultiCounterStat
+ timeExceeded tcpip.MultiCounterStat
+ paramProblem tcpip.MultiCounterStat
+ routerSolicit tcpip.MultiCounterStat
+ routerAdvert tcpip.MultiCounterStat
+ neighborSolicit tcpip.MultiCounterStat
+ neighborAdvert tcpip.MultiCounterStat
+ redirectMsg tcpip.MultiCounterStat
+ multicastListenerQuery tcpip.MultiCounterStat
+ multicastListenerReport tcpip.MultiCounterStat
+ multicastListenerDone tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) {
+ m.echoRequest.Init(a.EchoRequest, b.EchoRequest)
+ m.echoReply.Init(a.EchoReply, b.EchoReply)
+ m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
+ m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig)
+ m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
+ m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
+ m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit)
+ m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert)
+ m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit)
+ m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert)
+ m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg)
+ m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery)
+ m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport)
+ m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats)
+
+// LINT.IfChange(multiCounterICMPv6SentPacketStats)
+
+type multiCounterICMPv6SentPacketStats struct {
+ multiCounterICMPv6PacketStats
+ dropped tcpip.MultiCounterStat
+ rateLimited tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) {
+ m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+ m.rateLimited.Init(a.RateLimited, b.RateLimited)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats)
+
+// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats)
+
+type multiCounterICMPv6ReceivedPacketStats struct {
+ multiCounterICMPv6PacketStats
+ unrecognized tcpip.MultiCounterStat
+ invalid tcpip.MultiCounterStat
+ routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) {
+ m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
+ m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
+ m.invalid.Init(a.Invalid, b.Invalid)
+ m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats)
+
+// LINT.IfChange(multiCounterICMPv6Stats)
+
+type multiCounterICMPv6Stats struct {
+ packetsSent multiCounterICMPv6SentPacketStats
+ packetsReceived multiCounterICMPv6ReceivedPacketStats
+}
+
+func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6Stats)
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index d0ffc299a..bd62c4482 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -6,8 +6,10 @@ go_library(
name = "testutil",
srcs = [
"testutil.go",
+ "testutil_unsafe.go",
],
visibility = [
+ "//pkg/tcpip/network/arp:__pkg__",
"//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 3af44991f..f5fa77b65 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -19,6 +19,8 @@ package testutil
import (
"fmt"
"math/rand"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -33,7 +35,7 @@ type MockLinkEndpoint struct {
WrittenPackets []*stack.PacketBuffer
mtu uint32
- err *tcpip.Error
+ err tcpip.Error
allowPackets int
}
@@ -41,7 +43,7 @@ type MockLinkEndpoint struct {
//
// err is the error that will be returned once allowPackets packets are written
// to the endpoint.
-func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
return &MockLinkEndpoint{
mtu: mtu,
err: err,
@@ -62,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if ep.allowPackets == 0 {
return ep.err
}
@@ -72,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
var n int
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -127,3 +129,69 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi
}
return pkt
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go
new file mode 100644
index 000000000..5ff764800
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil_unsafe.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafeExposeUnexportedFields takes a Value and returns a version of it in
+// which even unexported fields can be read and written.
+func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
+ return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 2bad05a2e..57abec5c9 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -18,5 +18,6 @@ go_test(
library = ":ports",
deps = [
"//pkg/tcpip",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index d87193650..11dbdbbcf 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -329,7 +329,7 @@ func NewPortManager() *PortManager {
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
offset := uint32(rand.Int31n(numEphemeralPorts))
return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
}
@@ -348,7 +348,7 @@ func (s *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
if err == nil {
s.incPortHint()
@@ -361,7 +361,7 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
for i := uint32(0); i < count; i++ {
port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
@@ -374,7 +374,7 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
}
- return 0, tcpip.ErrNoPortAvailable
+ return 0, &tcpip.ErrNoPortAvailable{}
}
// IsPortAvailable tests if the given port is available on all given protocols.
@@ -404,7 +404,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// An optional testPort closure can be passed in which if provided will be used
// to test if the picked port can be used. The function should return true if
// the port is safe to use, false otherwise.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -414,17 +414,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// protocols.
if port != 0 {
if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
- return 0, tcpip.ErrPortInUse
+ return 0, &tcpip.ErrPortInUse{}
}
if testPort != nil && !testPort(port) {
s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
- return 0, tcpip.ErrPortInUse
+ return 0, &tcpip.ErrPortInUse{}
}
return port, nil
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
return false, nil
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 4bc949fd8..e70fbb72b 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -18,6 +18,7 @@ import (
"math/rand"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -32,7 +33,7 @@ const (
type portReserveTestAction struct {
port uint16
ip tcpip.Address
- want *tcpip.Error
+ want tcpip.Error
flags Flags
release bool
device tcpip.NICID
@@ -50,19 +51,19 @@ func TestPortReservation(t *testing.T) {
{port: 80, ip: fakeIPAddress, want: nil},
{port: 80, ip: fakeIPAddress1, want: nil},
/* N.B. Order of tests matters! */
- {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
},
},
{
tname: "bind to inaddr any",
actions: []portReserveTestAction{
{port: 22, ip: anyIPAddress, want: nil},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
/* release fakeIPAddress, but anyIPAddress is still inuse */
{port: 22, ip: fakeIPAddress, release: true},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
{port: 22, ip: anyIPAddress, want: nil, release: true},
{port: 22, ip: fakeIPAddress, want: nil},
@@ -80,8 +81,8 @@ func TestPortReservation(t *testing.T) {
{port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
@@ -91,14 +92,14 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
{port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
@@ -107,7 +108,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind twice with device fails",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 3, want: nil},
- {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind to device",
@@ -119,50 +120,50 @@ func TestPortReservation(t *testing.T) {
tname: "bind to device and then without device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind without device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, want: nil},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "binding with reuseport and device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "mixing reuseport and not reuseport by binding to device",
@@ -177,14 +178,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind and release",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
// Release the bind to device 0 and try again.
@@ -195,7 +196,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind twice with reuseport once",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "release an unreserved device",
@@ -213,16 +214,16 @@ func TestPortReservation(t *testing.T) {
tname: "bind with reuseaddr",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
},
}, {
tname: "bind twice with reuseaddr once",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport",
@@ -236,14 +237,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport, and then reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
@@ -264,14 +265,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseport, and then reuseaddr and reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
@@ -283,7 +284,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind tuple with reuseaddr, and then wildcard",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
@@ -295,7 +296,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind tuple with reuseaddr, and then wildcard",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind two tuples with reuseaddr",
@@ -313,7 +314,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind wildcard, and then tuple with reuseaddr",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind wildcard twice with reuseaddr",
@@ -333,8 +334,8 @@ func TestPortReservation(t *testing.T) {
continue
}
gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
- if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
+ if diff := cmp.Diff(test.want, err); diff != "" {
+ t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -345,30 +346,29 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- customErr := &tcpip.Error{}
for _, test := range []struct {
name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
+ f func(port uint16) (bool, tcpip.Error)
+ wantErr tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
{
name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
+ f: func(port uint16) (bool, tcpip.Error) {
+ return false, &tcpip.ErrBadBuffer{}
},
- wantErr: customErr,
+ wantErr: &tcpip.ErrBadBuffer{},
},
{
name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port == FirstEphemeral+42 {
return true, nil
}
@@ -378,49 +378,52 @@ func TestPickEphemeralPort(t *testing.T) {
},
{
name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port < FirstEphemeral {
return true, nil
}
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
- if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ port, err := pm.PickEphemeralPort(test.f)
+ if diff := cmp.Diff(test.wantErr, err); diff != "" {
+ t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
+ }
+ if port != test.wantPort {
+ t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
}
})
}
}
func TestPickEphemeralPortStable(t *testing.T) {
- customErr := &tcpip.Error{}
for _, test := range []struct {
name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
+ f func(port uint16) (bool, tcpip.Error)
+ wantErr tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
{
name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
+ f: func(port uint16) (bool, tcpip.Error) {
+ return false, &tcpip.ErrBadBuffer{}
},
- wantErr: customErr,
+ wantErr: &tcpip.ErrBadBuffer{},
},
{
name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port == FirstEphemeral+42 {
return true, nil
}
@@ -430,20 +433,24 @@ func TestPickEphemeralPortStable(t *testing.T) {
},
{
name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port < FirstEphemeral {
return true, nil
}
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
- if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ port, err := pm.PickEphemeralPortStable(portOffset, test.f)
+ if diff := cmp.Diff(test.wantErr, err); diff != "" {
+ t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
+ }
+ if port != test.wantPort {
+ t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
}
})
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index cf0a5fefe..db9b91815 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,7 +8,6 @@ go_binary(
visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3b4f900e3..856ea998d 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -41,7 +41,7 @@
package main
import (
- "bufio"
+ "bytes"
"fmt"
"log"
"math/rand"
@@ -51,7 +51,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
@@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
close(ch)
}()
- r := bufio.NewReader(os.Stdin)
- for {
- v := buffer.NewView(1024)
- n, err := r.Read(v)
- if err != nil {
- return
- }
-
- v.CapLength(n)
- for len(v) > 0 {
- n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- if err != nil {
- fmt.Println("Write failed:", err)
- return
+ var b bytes.Buffer
+ if err := func() error {
+ for {
+ if _, err := b.ReadFrom(os.Stdin); err != nil {
+ return fmt.Errorf("b.ReadFrom failed: %w", err)
}
- v.TrimFront(int(n))
+ for b.Len() != 0 {
+ if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil {
+ return fmt.Errorf("ep.Write failed: %s", err)
+ }
+ }
}
+ }(); err != nil {
+ fmt.Println(err)
}
}
@@ -179,7 +175,7 @@ func main() {
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventOut)
terr := ep.Connect(remote)
- if terr == tcpip.ErrConnectStarted {
+ if _, ok := terr.(*tcpip.ErrConnectStarted); ok {
fmt.Println("Connect is pending...")
<-notifyCh
terr = ep.LastError()
@@ -202,11 +198,11 @@ func main() {
for {
_, err := ep.Read(os.Stdout, tcpip.ReadOptions{})
if err != nil {
- if err == tcpip.ErrClosedForReceive {
+ if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
break
}
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 3ac562756..9b23df3a9 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,6 +20,7 @@
package main
import (
+ "bytes"
"flag"
"io"
"log"
@@ -50,7 +51,7 @@ type endpointWriter struct {
}
type tcpipError struct {
- inner *tcpip.Error
+ inner tcpip.Error
}
func (e *tcpipError) Error() string {
@@ -58,7 +59,9 @@ func (e *tcpipError) Error() string {
}
func (e *endpointWriter) Write(p []byte) (int, error) {
- n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(p)
+ n, err := e.ep.Write(&r, tcpip.WriteOptions{})
if err != nil {
return int(n), &tcpipError{
inner: err,
@@ -86,7 +89,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
for {
_, err := ep.Read(&w, tcpip.ReadOptions{})
if err != nil {
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
@@ -214,7 +217,7 @@ func main() {
for {
n, wq, err := ep.Accept(nil)
if err != nil {
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index f3ad40fdf..019d6a63c 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,11 +15,16 @@
package tcpip
import (
+ "math"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
)
+// PacketOverheadFactor is used to multiply the value provided by the user on a
+// SetSockOpt for setting the send/receive buffer sizes sockets.
+const PacketOverheadFactor = 2
+
// SocketOptionsHandler holds methods that help define endpoint specific
// behavior for socket level socket options. These must be implemented by
// endpoints to get notified when socket level options are set.
@@ -41,13 +46,19 @@ type SocketOptionsHandler interface {
OnCorkOptionSet(v bool)
// LastError is invoked when SO_ERROR is read for an endpoint.
- LastError() *Error
+ LastError() Error
// UpdateLastError updates the endpoint specific last error field.
- UpdateLastError(err *Error)
+ UpdateLastError(err Error)
// HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
HasNIC(v int32) bool
+
+ // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE.
+ GetSendBufferSize() (int64, Error)
+
+ // IsUnixSocket is invoked to check if the socket is of unix domain.
+ IsUnixSocket() bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -72,18 +83,39 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
// LastError implements SocketOptionsHandler.LastError.
-func (*DefaultSocketOptionsHandler) LastError() *Error {
+func (*DefaultSocketOptionsHandler) LastError() Error {
return nil
}
// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
-func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+func (*DefaultSocketOptionsHandler) UpdateLastError(Error) {}
// HasNIC implements SocketOptionsHandler.HasNIC.
func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
return false
}
+// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize.
+func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) {
+ return 0, nil
+}
+
+// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket.
+func (*DefaultSocketOptionsHandler) IsUnixSocket() bool {
+ return false
+}
+
+// StackHandler holds methods to access the stack options. These must be
+// implemented by the stack.
+type StackHandler interface {
+ // Option allows retrieving stack wide options.
+ Option(option interface{}) Error
+
+ // TransportProtocolOption allows retrieving individual protocol level
+ // option values.
+ TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error
+}
+
// SocketOptions contains all the variables which store values for SOL_SOCKET,
// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
@@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
type SocketOptions struct {
handler SocketOptionsHandler
+ // StackHandler is initialized at the creation time and will not change.
+ stackHandler StackHandler `state:"manual"`
+
// These fields are accessed and modified using atomic operations.
// broadcastEnabled determines whether datagram sockets are allowed to
@@ -170,6 +205,14 @@ type SocketOptions struct {
// bindToDevice determines the device to which the socket is bound.
bindToDevice int32
+ // getSendBufferLimits provides the handler to get the min, default and
+ // max size for send buffer. It is initialized at the creation time and
+ // will not change.
+ getSendBufferLimits GetSendBufferLimits `state:"manual"`
+
+ // sendBufferSize determines the send buffer size for this socket.
+ sendBufferSize int64
+
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -180,8 +223,10 @@ type SocketOptions struct {
// InitHandler initializes the handler. This must be called before using the
// socket options utility.
-func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) {
so.handler = handler
+ so.stackHandler = stack
+ so.getSendBufferLimits = getSendBufferLimits
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -193,7 +238,7 @@ func storeAtomicBool(addr *uint32, v bool) {
}
// SetLastError sets the last error for a socket.
-func (so *SocketOptions) SetLastError(err *Error) {
+func (so *SocketOptions) SetLastError(err Error) {
so.handler.UpdateLastError(err)
}
@@ -378,7 +423,7 @@ func (so *SocketOptions) SetRecvError(v bool) {
}
// GetLastError gets value for SO_ERROR option.
-func (so *SocketOptions) GetLastError() *Error {
+func (so *SocketOptions) GetLastError() Error {
return so.handler.LastError()
}
@@ -435,7 +480,7 @@ type SockError struct {
sockErrorEntry
// Err is the error caused by the errant packet.
- Err *Error
+ Err Error
// ErrOrigin indicates the error origin.
ErrOrigin SockErrOrigin
// ErrType is the type in the ICMP header.
@@ -493,7 +538,7 @@ func (so *SocketOptions) QueueErr(err *SockError) {
}
// QueueLocalErr queues a local error onto the local queue.
-func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
so.QueueErr(&SockError{
Err: err,
ErrOrigin: SockExtErrorOriginLocal,
@@ -510,11 +555,52 @@ func (so *SocketOptions) GetBindToDevice() int32 {
}
// SetBindToDevice sets value for SO_BINDTODEVICE option.
-func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
if !so.handler.HasNIC(bindToDevice) {
- return ErrUnknownDevice
+ return &ErrUnknownDevice{}
}
atomic.StoreInt32(&so.bindToDevice, bindToDevice)
return nil
}
+
+// GetSendBufferSize gets value for SO_SNDBUF option.
+func (so *SocketOptions) GetSendBufferSize() (int64, Error) {
+ if so.handler.IsUnixSocket() {
+ return so.handler.GetSendBufferSize()
+ }
+ return atomic.LoadInt64(&so.sendBufferSize), nil
+}
+
+// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
+// stack handler should be invoked to set the send buffer size.
+func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
+ if so.handler.IsUnixSocket() {
+ return
+ }
+
+ v := sendBufferSize
+ if notify {
+ // TODO(b/176170271): Notify waiters after size has grown.
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ ss := so.getSendBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Validate the send buffer size with min and max values.
+ // Multiply it by factor of 2.
+ if v > max {
+ v = max
+ }
+
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
+ }
+ } else {
+ v = math.MaxInt32
+ }
+ }
+ atomic.StoreInt64(&so.sendBufferSize, v)
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index cd423bf71..e5590ecc0 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
@@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
// AddAndAcquireTemporaryAddress adds a temporary address.
//
-// Returns tcpip.ErrDuplicateAddress if the address exists.
+// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// The temporary address's endpoint is acquired and returned.
-func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
@@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// If the addressable endpoint already has the address in a non-permanent state,
// and addAndAcquireAddressLocked is adding a permanent address, that address is
// promoted in place and its properties set to the properties provided. If the
-// address already exists in any other state, then tcpip.ErrDuplicateAddress is
+// address already exists in any other state, then *tcpip.ErrDuplicateAddress is
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We are adding a non-permanent address but the address exists. No need
// to go any further since we can only promote existing temporary/expired
// addresses to permanent.
- return nil, tcpip.ErrDuplicateAddress
+ return nil, &tcpip.ErrDuplicateAddress{}
}
addrState.mu.Lock()
@@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState.mu.Unlock()
// We are adding a permanent address but a permanent address already
// exists.
- return nil, tcpip.ErrDuplicateAddress
+ return nil, &tcpip.ErrDuplicateAddress{}
}
if addrState.mu.refs == 0 {
@@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// RemovePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
return a.removePermanentAddressLocked(addr)
@@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t
// requirements.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error {
addrState, ok := a.mu.endpoints[addr]
if !ok {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
return a.removePermanentEndpointLocked(addrState)
@@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre
// RemovePermanentEndpoint removes the passed endpoint if it is associated with
// a and permanent.
-func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error {
+func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error {
addrState, ok := ep.(*addressState)
if !ok || addrState.addressableEndpointState != a {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
a.mu.Lock()
@@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
// requirements.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error {
+func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error {
if !addrState.GetKind().IsPermanent() {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
addrState.SetKind(PermanentExpired)
@@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() {
defer a.mu.Unlock()
for _, ep := range a.mu.endpoints {
- // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
+ // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
- if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := a.removePermanentEndpointLocked(ep); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
}
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 5e649cca6..54617f2e6 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -198,15 +198,15 @@ type bucket struct {
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
+func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, tcpip.ErrUnknownProtocol
+ return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, tcpip.ErrUnknownProtocol
+ return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
@@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
conn, _ := ct.connForTID(tid)
if conn == nil {
// Not a tracked connection.
- return "", 0, tcpip.ErrNotConnected
+ return "", 0, &tcpip.ErrNotConnected{}
} else if conn.manip == manipNone {
// Unmanipulated connection.
- return "", 0, tcpip.ErrInvalidOptionValue
+ return "", 0, &tcpip.ErrInvalidOptionValue{}
}
return conn.original.dstAddr, conn.original.dstPort, nil
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4908848e9..63a42a2ea 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -41,6 +41,8 @@ const (
protocolNumberOffset = 2
)
+var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
@@ -53,9 +55,7 @@ type fwdTestNetworkEndpoint struct {
dispatcher TransportDispatcher
}
-var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
-
-func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
return nil
}
@@ -104,7 +104,7 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
-func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
@@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
@@ -124,14 +124,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error {
// The network header should not already be populated.
if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
@@ -141,6 +141,21 @@ func (f *fwdTestNetworkEndpoint) Close() {
f.AddressableEndpointState.Cleanup()
}
+// Stats implements stack.NetworkEndpoint.
+func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats {
+ return &fwdTestNetworkEndpointStats{}
+}
+
+var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil)
+
+type fwdTestNetworkEndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {}
+
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
+
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
@@ -158,18 +173,15 @@ type fwdTestNetworkProtocol struct {
}
}
-var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
-var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
-
-func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
-func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
return fwdTestNetDefaultPrefixLen
}
@@ -195,19 +207,19 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress
return e
}
-func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
-func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
@@ -307,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -323,7 +335,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.WritePacket(r, gso, protocol, pkt)
@@ -356,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
UseNeighborCache: useNeighborCache,
})
- if !useNeighborCache {
- proto.addrCache = s.linkAddrCache
- }
-
// Enable forwarding.
s.SetForwarding(proto.Number(), true)
@@ -389,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
t.Fatal("AddAddress #2 failed:", err)
}
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("NIC 2 does not exist")
+ }
if useNeighborCache {
// Control the neighbor cache for NIC 2.
- nic, ok := s.nics[2]
- if !ok {
- t.Fatal("failed to get the neighbor cache for NIC 2")
- }
proto.neigh = nic.neigh
+ } else {
+ proto.addrCache = nic.linkAddrCache
}
// Route all packets to NIC 2.
@@ -481,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -607,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
}
},
},
@@ -692,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -768,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -858,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 09c7811fa..63832c200 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
// ReplaceTable replaces or inserts table by name. It panics when an invalid id
// is provided.
-func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -267,11 +267,11 @@ const (
// dropped.
//
// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
-// which address and nicName can be gathered. Currently, address is only
-// needed for prerouting and nicName is only needed for output.
+// which address can be gathered. Currently, address is only needed for
+// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(pkt, hook, nicName) {
+ if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName)
if hotdrop {
return RuleDrop, 0
}
@@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
- return "", 0, tcpip.ErrNotConnected
+ return "", 0, &tcpip.ErrNotConnected{}
}
return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 56a3e7861..fd9d61e39 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -210,8 +210,19 @@ type IPHeaderFilter struct {
// filter will match packets that fail the source comparison.
SrcInvert bool
- // OutputInterface matches the name of the outgoing interface for the
- // packet.
+ // InputInterface matches the name of the incoming interface for the packet.
+ InputInterface string
+
+ // InputInterfaceMask masks the characters of the interface name when
+ // comparing with InputInterface.
+ InputInterfaceMask string
+
+ // InputInterfaceInvert inverts the meaning of incoming interface check,
+ // i.e. when true the filter will match packets that fail the incoming
+ // interface comparison.
+ InputInterfaceInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the packet.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
@@ -228,7 +239,7 @@ type IPHeaderFilter struct {
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
-func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
// TODO(gvisor.dev/issue/170): Support other filter fields.
@@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return false
}
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(fl.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which
- // begins with the name should be matched.
- ifName := fl.OutputInterface
- matches := nicName == ifName
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- }
- return fl.OutputInterfaceInvert != matches
+ switch hook {
+ case Prerouting, Input:
+ return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
+ case Output:
+ return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
+ case Forward, Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ return true
+ default:
+ panic(fmt.Sprintf("unknown hook: %d", hook))
}
+}
- return true
+func matchIfName(nicName string, ifName string, invert bool) bool {
+ n := len(ifName)
+ if n == 0 {
+ // If the interface name is omitted in the filter, any interface will match.
+ return true
+ }
+ // If the interface name ends with '+', any interface which begins with the
+ // name should be matched.
+ var matches bool
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return matches != invert
}
// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
@@ -320,7 +340,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index b600a1cab..930b8f795 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -24,12 +24,16 @@ import (
const linkAddrCacheSize = 512 // max cache entries
+var _ LinkAddressCache = (*linkAddrCache)(nil)
+
// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
//
// The entries are stored in a ring buffer, oldest entry replaced first.
//
// This struct is safe for concurrent use.
type linkAddrCache struct {
+ nic *NIC
+
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
@@ -41,9 +45,9 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- cache struct {
+ mu struct {
sync.Mutex
- table map[tcpip.FullAddress]*linkAddrEntry
+ table map[tcpip.Address]*linkAddrEntry
lru linkAddrEntryList
}
}
@@ -77,31 +81,42 @@ type linkAddrEntry struct {
// linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
- // TODO(gvisor.dev/issue/5150): move these fields under mu.
- // mu protects the fields below.
- mu sync.RWMutex
+ cache *linkAddrCache
- addr tcpip.FullAddress
- linkAddr tcpip.LinkAddress
- expiration time.Time
- s entryState
+ mu struct {
+ sync.RWMutex
- // done is closed when address resolution is complete. It is nil iff s is
- // incomplete and resolution is not yet in progress.
- done chan struct{}
+ addr tcpip.Address
+ linkAddr tcpip.LinkAddress
+ expiration time.Time
+ s entryState
- // onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
+ done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(LinkResolutionResult)
+ }
}
func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
- for _, callback := range e.onResolve {
- callback(linkAddr, len(linkAddr) != 0)
+ res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0}
+ for _, callback := range e.mu.onResolve {
+ callback(res)
}
- e.onResolve = nil
- if ch := e.done; ch != nil {
+ e.mu.onResolve = nil
+ if ch := e.mu.done; ch != nil {
close(ch)
- e.done = nil
+ e.mu.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0)
}
}
@@ -114,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
//
// Precondition: e.mu must be locked
func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
- if e.s == incomplete && ns == ready {
- e.notifyCompletionLocked(e.linkAddr)
+ if e.mu.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.mu.linkAddr)
}
- if expiration.IsZero() || expiration.After(e.expiration) {
- e.expiration = expiration
+ if expiration.IsZero() || expiration.After(e.mu.expiration) {
+ e.mu.expiration = expiration
}
- e.s = ns
+ e.mu.s = ns
}
// add adds a k -> v mapping to the cache.
-func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
// relative to the time when information was learned, rather than when it
// happened to be inserted into the cache.
expiration := time.Now().Add(c.ageLimit)
- c.cache.Lock()
+ c.mu.Lock()
entry := c.getOrCreateEntryLocked(k)
- c.cache.Unlock()
-
entry.mu.Lock()
defer entry.mu.Unlock()
- entry.linkAddr = v
+ c.mu.Unlock()
+
+ entry.mu.linkAddr = v
entry.changeStateLocked(ready, expiration)
}
@@ -150,19 +165,19 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
- if entry, ok := c.cache.table[k]; ok {
- c.cache.lru.Remove(entry)
- c.cache.lru.PushFront(entry)
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
+ if entry, ok := c.mu.table[k]; ok {
+ c.mu.lru.Remove(entry)
+ c.mu.lru.PushFront(entry)
return entry
}
var entry *linkAddrEntry
- if len(c.cache.table) == linkAddrCacheSize {
- entry = c.cache.lru.Back()
+ if len(c.mu.table) == linkAddrCacheSize {
+ entry = c.mu.lru.Back()
entry.mu.Lock()
- delete(c.cache.table, entry.addr)
- c.cache.lru.Remove(entry)
+ delete(c.mu.table, entry.mu.addr)
+ c.mu.lru.Remove(entry)
// Wake waiters and mark the soon-to-be-reused entry as expired.
entry.notifyCompletionLocked("" /* linkAddr */)
@@ -172,53 +187,56 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
*entry = linkAddrEntry{
- addr: k,
- s: incomplete,
+ cache: c,
}
- c.cache.table[k] = entry
- c.cache.lru.PushFront(entry)
+ entry.mu.Lock()
+ entry.mu.addr = k
+ entry.mu.s = incomplete
+ entry.mu.Unlock()
+ c.mu.table[k] = entry
+ c.mu.lru.PushFront(entry)
return entry
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
- c.cache.Lock()
- defer c.cache.Unlock()
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+ c.mu.Lock()
entry := c.getOrCreateEntryLocked(k)
entry.mu.Lock()
defer entry.mu.Unlock()
+ c.mu.Unlock()
- switch s := entry.s; s {
+ switch s := entry.mu.s; s {
case ready:
- if !time.Now().After(entry.expiration) {
+ if !time.Now().After(entry.mu.expiration) {
// Not expired.
if onResolve != nil {
- onResolve(entry.linkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true})
}
- return entry.linkAddr, nil, nil
+ return entry.mu.linkAddr, nil, nil
}
entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
if onResolve != nil {
- entry.onResolve = append(entry.onResolve, onResolve)
+ entry.mu.onResolve = append(entry.mu.onResolve, onResolve)
}
- if entry.done == nil {
- entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ if entry.mu.done == nil {
+ entry.mu.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
- return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
+ linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
@@ -234,10 +252,10 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has
// succeeded and mark the entry accordingly. Returns true if request can stop,
// false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
- c.cache.Lock()
- defer c.cache.Unlock()
- entry, ok := c.cache.table[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ entry, ok := c.mu.table[k]
if !ok {
// Entry was evicted from the cache.
return true
@@ -245,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
entry.mu.Lock()
defer entry.mu.Unlock()
- switch s := entry.s; s {
+ switch s := entry.mu.s; s {
case ready:
// Entry was made ready by resolver.
case incomplete:
@@ -255,19 +273,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
}
// Max number of retries reached, delete entry.
entry.notifyCompletionLocked("" /* linkAddr */)
- delete(c.cache.table, k)
+ delete(c.mu.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
return true
}
-func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
c := &linkAddrCache{
+ nic: nic,
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
}
- c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize)
return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d7ac6cf5f..466a5e8d9 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -26,7 +26,7 @@ import (
)
type testaddr struct {
- addr tcpip.FullAddress
+ addr tcpip.Address
linkAddr tcpip.LinkAddress
}
@@ -35,7 +35,7 @@ var testAddrs = func() []testaddr {
for i := 0; i < 4*linkAddrCacheSize; i++ {
addr := fmt.Sprintf("Addr%06d", i)
addrs = append(addrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ addr: tcpip.Address(addr),
linkAddr: tcpip.LinkAddress("Link" + addr),
})
}
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
// TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
@@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
for _, ta := range testAddrs {
- if ta.addr.Addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
+ if ta.addr == addr {
+ r.cache.AddLinkAddress(ta.addr, ta.linkAddr)
break
}
}
@@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return 1
}
-func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) {
var attemptedResolution bool
for {
got, ch, err := c.get(addr, linkRes, "", nil, nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
if attemptedResolution {
- return got, tcpip.ErrTimeout
+ return got, &tcpip.ErrTimeout{}
}
attemptedResolution = true
<-ch
@@ -93,17 +93,23 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
+func newEmptyNIC() *NIC {
+ n := &NIC{}
+ n.linkResQueue.init(n)
+ return n
+}
+
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
@@ -111,25 +117,25 @@ func TestCacheOverflow(t *testing.T) {
e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
- c.cache.Lock()
- defer c.cache.Unlock()
+ c.mu.Lock()
+ defer c.mu.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ if entry, ok := c.mu.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
@@ -137,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) {
wg.Add(1)
go func() {
for _, e := range testAddrs {
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
}
wg.Done()
}()
@@ -150,52 +156,53 @@ func TestCacheConcurrent(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
e = testAddrs[0]
- c.cache.Lock()
- defer c.cache.Unlock()
- if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if entry, ok := c.mu.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
e := testAddrs[0]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
+ _, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err)
}
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
e := testAddrs[0]
l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
- c.add(e.addr, l2)
+ c.AddLinkAddress(e.addr, l2)
got, _, err = c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != l2 {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2)
}
}
@@ -206,15 +213,15 @@ func TestCacheResolution(t *testing.T) {
//
// Using a large resolution timeout decreases the probability of experiencing
// this race condition and does not affect how long this test takes to run.
- c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+ t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err)
}
if got != ta.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+ t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
}
}
@@ -223,16 +230,16 @@ func TestCacheResolution(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
}
}
func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
var requestCount uint32
@@ -244,17 +251,18 @@ func TestCacheResolutionFailed(t *testing.T) {
e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("getBlocking(_, %s, _): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr)
}
before := atomic.LoadUint32(&requestCount)
- e.addr.Addr += "2"
- if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ e.addr += "2"
+ a, err := getBlocking(c, e.addr, linkRes)
+ if _, ok := err.(*tcpip.ErrTimeout); !ok {
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -265,11 +273,12 @@ func TestCacheResolutionFailed(t *testing.T) {
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 500 * time.Millisecond
expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
- if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ a, err := getBlocking(c, e.addr, linkRes)
+ if _, ok := err.(*tcpip.ErrTimeout); !ok {
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 61636cae5..64383bc7c 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -45,6 +45,8 @@ const (
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+ defaultPrefixLen = 128
+
// Extra time to use when waiting for an async event to occur.
defaultAsyncPositiveEventTimeout = 10 * time.Second
@@ -102,7 +104,7 @@ type ndpDADEvent struct {
nicID tcpip.NICID
addr tcpip.Address
resolved bool
- err *tcpip.Error
+ err tcpip.Error
}
type ndpRouterEvent struct {
@@ -172,7 +174,7 @@ type ndpDispatcher struct {
}
// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
nicID,
@@ -309,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string {
return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
}
@@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Should get the address immediately since we should not have performed
@@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) {
default:
t.Fatal("expected DAD event")
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatal(err)
}
// We should not have sent any NDP NS messages.
@@ -440,31 +442,31 @@ func TestDADResolve(t *testing.T) {
NIC: nicID,
}})
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Make sure the address does not resolve before the resolution time has
// passed.
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
}
// Should not get a route even if we specify the local address as the
// tentative address.
{
r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
- if err != tcpip.ErrNoRoute {
- t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
}
if r != nil {
r.Release()
@@ -472,8 +474,8 @@ func TestDADResolve(t *testing.T) {
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
- if err != tcpip.ErrNoRoute {
- t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
}
if r != nil {
r.Release()
@@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if addr.Address != addr1 {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Error(err)
}
// Should get a route using the address now that it is resolved.
{
@@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) {
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Receive a packet to simulate an address conflict.
@@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Attempting to add the address again should not fail if the address's
@@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
test.stopFn(t, s)
@@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) {
}
if !test.skipFinalAddrCheck {
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
}
@@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Add addresses for each NIC.
- if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
}
- if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
+ addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
}
expectDADEvent(nicID2, addr2)
- if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
+ addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
}
expectDADEvent(nicID3, addr3)
// Address should not be considered bound to NIC(1) yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Should get the address on NIC(2) and NIC(3)
@@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) {
// it as the stack was configured to not do DAD by
// default and we only updated the NDP configurations on
// NIC(1).
- addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
+ if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatal(err)
}
- addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
+ if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatal(err)
}
// Sleep until right (500ms before) before resolution to
// make sure the address didn't resolve on NIC(1) yet.
const delta = 500 * time.Millisecond
time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2808,6 +2775,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2827,10 +2795,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
Gateway: llAddr3,
NIC: nicID,
}})
+
if useNeighborCache {
- s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
} else {
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
}
return ndpDisp, e, s
}
@@ -2940,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3088,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3238,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
@@ -3255,8 +3222,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
defer ep.Close()
ep.SocketOptions().SetV6Only(true)
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ {
+ err := ep.Connect(dstAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{})
+ }
}
})
}
@@ -3615,10 +3585,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index acee72572..88a3ff776 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// packet prompting NUD/link address resolution.
//
// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) {
entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
@@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
- onResolve(entry.neigh.LinkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true})
}
return entry.neigh, nil, nil
case Unknown, Incomplete, Failed:
@@ -154,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
entry.handlePacketQueuedLocked(localAddr)
- return entry.neigh, entry.done, tcpip.ErrWouldBlock
+ return entry.neigh, entry.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
@@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li
// no matching entry for the remote address.
}
-// HandleUpperLevelConfirmation implements
-// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
-// RFC 4861 section 7.3.1.
-func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+// handleUpperLevelConfirmation processes a confirmation of reachablity from
+// some protocol that operates at a layer above the IP/link layer.
+func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
n.mu.RLock()
entry, ok := n.cache[addr]
n.mu.RUnlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b96a56612..2870e4f66 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -194,7 +194,7 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
if !r.dropReplies {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
@@ -251,8 +251,8 @@ func TestNeighborCacheGetConfig(t *testing.T) {
// No events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -273,8 +273,8 @@ func TestNeighborCacheSetConfig(t *testing.T) {
// No events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -295,8 +295,9 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -321,11 +322,11 @@ func TestNeighborCacheEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
@@ -335,8 +336,8 @@ func TestNeighborCacheEntry(t *testing.T) {
// No more events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -359,8 +360,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -385,11 +387,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
neigh.removeEntry(entry.Addr)
@@ -407,15 +409,18 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ {
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ }
}
}
@@ -462,8 +467,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -506,11 +512,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
})
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -531,15 +537,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
// No more events should have been dispatched.
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
return nil
@@ -579,8 +585,9 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -603,11 +610,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the entry
@@ -626,11 +633,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -668,11 +675,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the static entry that was just added
@@ -681,8 +688,8 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// No more events should have been dispatched.
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -712,11 +719,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Add a duplicate entry with a different link address
@@ -736,8 +743,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
}
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
}
@@ -774,11 +781,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the static entry that was just added
@@ -796,11 +803,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -830,8 +837,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -854,11 +862,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Override the entry with a static one using the same address
@@ -886,11 +894,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -932,8 +940,8 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Static,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -948,11 +956,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
opts := overflowOptions{
@@ -989,8 +997,9 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -1014,11 +1023,11 @@ func TestNeighborCacheClear(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Add a static entry.
@@ -1037,11 +1046,11 @@ func TestNeighborCacheClear(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1072,8 +1081,8 @@ func TestNeighborCacheClear(t *testing.T) {
}
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1094,8 +1103,9 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1118,11 +1128,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Clear the cache.
@@ -1140,11 +1150,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1188,16 +1198,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1225,11 +1232,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1247,16 +1254,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1299,11 +1303,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1331,15 +1335,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
// No more events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1366,8 +1370,10 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ default:
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{})
}
}(entry)
}
@@ -1398,8 +1404,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
}
@@ -1423,16 +1429,13 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1455,8 +1458,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
@@ -1489,8 +1492,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: updatedLinkAddr,
State: Delay,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1507,8 +1510,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: updatedLinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
}
@@ -1539,16 +1542,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
// First, sanity check that resolution is working
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1567,8 +1567,8 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
// Verify address resolution fails for an unknown address.
@@ -1576,19 +1576,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry.Addr += "2"
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1627,19 +1621,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1674,19 +1662,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Perform address resolution with a faulty link, which will fail.
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1713,13 +1695,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Retry address resolution with a working link.
linkRes.dropReplies = false
{
- incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
if incompleteEntry.State != Incomplete {
t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
@@ -1772,16 +1754,13 @@ func BenchmarkCacheClear(b *testing.B) {
b.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- b.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
select {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 75afb3001..53ac9bb6e 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -96,7 +96,7 @@ type neighborEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
isRouter bool
job *tcpip.Job
@@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded}
for _, callback := range e.onResolve {
- callback(e.neigh.LinkAddr, succeeded)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index ec34ffa5a..140b8ca00 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
p := entryTestProbeInfo{
RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
@@ -266,16 +266,16 @@ func TestEntryInitiallyUnknown(t *testing.T) {
// No probes should have been sent.
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
// No events should have been dispatched.
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -299,16 +299,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
// No probes should have been sent.
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
// No events should have been dispatched.
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -333,10 +333,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -352,10 +352,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
}
{
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
}
@@ -374,10 +374,10 @@ func TestEntryUnknownToStale(t *testing.T) {
// No probes should have been sent.
runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -392,8 +392,8 @@ func TestEntryUnknownToStale(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -427,11 +427,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -453,10 +453,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -483,8 +483,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -515,10 +515,10 @@ func TestEntryIncompleteToReachable(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -553,8 +553,8 @@ func TestEntryIncompleteToReachable(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -579,10 +579,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -620,8 +620,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -646,10 +646,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -684,8 +684,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -710,10 +710,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -744,8 +744,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -785,10 +785,10 @@ func TestEntryIncompleteToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -812,8 +812,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -850,10 +850,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -903,8 +903,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -932,10 +932,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -977,8 +977,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1005,10 +1005,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1054,8 +1054,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -1083,10 +1083,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1134,8 +1134,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1157,10 +1157,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1212,8 +1212,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1235,10 +1235,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1290,8 +1290,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1313,10 +1313,10 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1358,8 +1358,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1381,10 +1381,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1439,8 +1439,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1462,10 +1462,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1520,8 +1520,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1543,10 +1543,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1601,8 +1601,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1624,10 +1624,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1678,8 +1678,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1701,10 +1701,10 @@ func TestEntryStaleToDelay(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1752,8 +1752,8 @@ func TestEntryStaleToDelay(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1780,10 +1780,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1851,8 +1851,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1880,10 +1880,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1958,8 +1958,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1987,10 +1987,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2065,8 +2065,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2088,10 +2088,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2147,8 +2147,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2170,10 +2170,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2231,8 +2231,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2254,10 +2254,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2319,8 +2319,8 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2343,11 +2343,11 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2372,10 +2372,10 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2418,8 +2418,8 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -2448,11 +2448,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2474,10 +2474,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2539,8 +2539,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2563,11 +2563,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2589,10 +2589,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2658,8 +2658,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2682,11 +2682,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2709,10 +2709,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2772,8 +2772,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2806,10 +2806,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2878,8 +2878,8 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2907,11 +2907,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2933,10 +2933,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3015,8 +3015,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3044,11 +3044,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3070,10 +3070,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3149,8 +3149,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3178,11 +3178,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3204,10 +3204,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3283,8 +3283,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3309,11 +3309,11 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3336,11 +3336,11 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff)
}
e.mu.Lock()
@@ -3406,8 +3406,8 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3449,10 +3449,10 @@ func TestEntryFailedToIncomplete(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -3498,8 +3498,8 @@ func TestEntryFailedToIncomplete(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0f545f255..e56a624fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -53,6 +53,8 @@ type NIC struct {
// complete.
linkResQueue packetsPendingLinkResolution
+ linkAddrCache *linkAddrCache
+
mu struct {
sync.RWMutex
spoofing bool
@@ -138,7 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.linkResQueue.init()
+ nic.linkResQueue.init(nic)
+ nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
@@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = new(packetEndpointList)
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic)
}
nic.LinkEndpoint.Attach(nic)
@@ -228,7 +231,9 @@ func (n *NIC) disableLocked() {
//
// This matches linux's behaviour at the time of writing:
// https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
- if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported {
+ switch err := n.clearNeighbors(); err.(type) {
+ case nil, *tcpip.ErrNotSupported:
+ default:
panic(fmt.Sprintf("n.clearNeighbors(): %s", err))
}
@@ -243,7 +248,7 @@ func (n *NIC) disableLocked() {
// address (ff02::1), start DAD for permanent addresses, and start soliciting
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
-func (n *NIC) enable() *tcpip.Error {
+func (n *NIC) enable() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -263,7 +268,7 @@ func (n *NIC) enable() *tcpip.Error {
// remove detaches NIC from the link endpoint and releases network endpoint
// resources. This guarantees no packets between this NIC and the network
// stack.
-func (n *NIC) remove() *tcpip.Error {
+func (n *NIC) remove() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -299,48 +304,69 @@ func (n *NIC) IsLoopback() bool {
}
// WritePacket implements NetworkLinkEndpoint.
-func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
- // As per relevant RFCs, we should queue packets while we wait for link
- // resolution to complete.
- //
- // RFC 1122 section 2.3.2.2 (for IPv4):
- // The link layer SHOULD save (rather than discard) at least
- // one (the latest) packet of each set of packets destined to
- // the same unresolved IP address, and transmit the saved
- // packet when the address has been resolved.
- //
- // RFC 4861 section 7.2.2 (for IPv6):
- // While waiting for address resolution to complete, the sender MUST, for
- // each neighbor, retain a small queue of packets waiting for address
- // resolution to complete. The queue MUST hold at least one packet, and MAY
- // contain more. However, the number of queued packets per neighbor SHOULD
- // be limited to some small value. When a queue overflows, the new arrival
- // SHOULD replace the oldest entry. Once address resolution completes, the
- // node transmits any queued packets.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r.Acquire()
- n.linkResQueue.enqueue(ch, r, protocol, pkt)
- return nil
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+ return err
+}
+
+func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ switch pkt := pkt.(type) {
+ case *PacketBuffer:
+ if err := n.writePacket(r, gso, protocol, pkt); err != nil {
+ return 0, err
}
- return err
+ return 1, nil
+ case *PacketBufferList:
+ return n.writePackets(r, gso, protocol, *pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
+}
- return n.writePacket(r.Fields(), gso, protocol, pkt)
+func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ routeInfo, _, err := r.resolvedFields(nil)
+ switch err.(type) {
+ case nil:
+ return n.writePacketBuffer(routeInfo, gso, protocol, pkt)
+ case *tcpip.ErrWouldBlock:
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and
+ // MAY contain more. However, the number of queued packets per neighbor
+ // SHOULD be limited to some small value. When a queue overflows, the new
+ // arrival SHOULD replace the oldest entry. Once address resolution
+ // completes, the node transmits any queued packets.
+ return n.linkResQueue.enqueue(r, gso, protocol, pkt)
+ default:
+ return 0, err
+ }
}
// WritePacketToRemote implements NetworkInterface.
-func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
var r RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
return n.writePacket(r, gso, protocol, pkt)
}
-func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
return err
}
@@ -351,10 +377,18 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
}
// WritePackets implements NetworkLinkEndpoint.
-func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
- // is being peformed like WritePacket.
- writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
+}
+
+func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
+ }
+
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
@@ -463,15 +497,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
@@ -535,63 +569,70 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
}
// removeAddress removes an address from n.
-func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error {
for _, ep := range n.networkEndpoints {
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
continue
}
- if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) {
+ case *tcpip.ErrBadLocalAddress:
continue
- } else {
+ default:
return err
}
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
+}
+
+func (n *NIC) confirmReachable(addr tcpip.Address) {
+ if n := n.neigh; n != nil {
+ n.handleUpperLevelConfirmation(addr)
+ }
}
-func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
if n.neigh != nil {
entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve)
return entry.LinkAddr, ch, err
}
- return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve)
+ return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve)
}
-func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
+func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) {
if n.neigh == nil {
- return nil, tcpip.ErrNotSupported
+ return nil, &tcpip.ErrNotSupported{}
}
return n.neigh.entries(), nil
}
-func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
n.neigh.addStaticEntry(addr, linkAddress)
return nil
}
-func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
if !n.neigh.removeEntry(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
return nil
}
-func (n *NIC) clearNeighbors() *tcpip.Error {
+func (n *NIC) clearNeighbors() tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
n.neigh.clear()
@@ -600,7 +641,7 @@ func (n *NIC) clearNeighbors() *tcpip.Error {
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
-func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
@@ -608,12 +649,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
ep, ok := n.networkEndpoints[protocol]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
return gep.JoinGroup(addr)
@@ -621,15 +662,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
ep, ok := n.networkEndpoints[protocol]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
return gep.LeaveGroup(addr)
@@ -879,9 +920,9 @@ func (n *NIC) Name() string {
}
// nudConfigs gets the NUD configurations for n.
-func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
+func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) {
if n.neigh == nil {
- return NUDConfigurations{}, tcpip.ErrNotSupported
+ return NUDConfigurations{}, &tcpip.ErrNotSupported{}
}
return n.neigh.config(), nil
}
@@ -890,22 +931,22 @@ func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
c.resetInvalidFields()
n.neigh.setConfig(c)
return nil
}
-func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
eps, ok := n.mu.packetEPs[netProto]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
eps.add(ep)
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5b5c58afb..2f719fbe5 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -39,7 +39,7 @@ type testIPv6Endpoint struct {
invalidatedRtr tcpip.Address
}
-func (*testIPv6Endpoint) Enable() *tcpip.Error {
+func (*testIPv6Endpoint) Enable() tcpip.Error {
return nil
}
@@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
}
// WritePacket implements NetworkEndpoint.WritePacket.
-func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
return nil
}
// WritePackets implements NetworkEndpoint.WritePackets.
-func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
// Our tests don't use this so we don't support it.
- return 0, tcpip.ErrNotSupported
+ return 0, &tcpip.ErrNotSupported{}
}
// WriteHeaderIncludedPacket implements
// NetworkEndpoint.WriteHeaderIncludedPacket.
-func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error {
// Our tests don't use this so we don't support it.
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
@@ -99,11 +99,20 @@ func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.invalidatedRtr = rtr
}
-var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+// Stats implements NetworkEndpoint.
+func (*testIPv6Endpoint) Stats() NetworkEndpointStats {
+ return &testIPv6EndpointStats{}
+}
+
+var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil)
+
+type testIPv6EndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*testIPv6EndpointStats) IsNetworkEndpointStats() {}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
-// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
-// believe it supports IPv6.
-//
// We use this instead of ipv6.protocol because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
type testIPv6Protocol struct{}
@@ -140,12 +149,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache,
}
// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
return nil
}
// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
return nil
}
@@ -160,15 +169,13 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
return 0, false, false
}
-var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
-
// LinkAddressProtocol implements LinkAddressResolver.
func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index 12d67409a..77926e289 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -174,10 +174,6 @@ type NUDHandler interface {
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
-
- // HandleUpperLevelConfirmation processes an incoming upper-level protocol
- // (e.g. TCP acknowledgements) reachability confirmation.
- HandleUpperLevelConfirmation(addr tcpip.Address)
}
// NUDConfigurations is the NUD configurations for the netstack. This is used
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 7bca1373e..ebfd5eb45 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -65,8 +65,9 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
// No NIC with ID 1 yet.
config := stack.NUDConfigurations{}
- if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ err := s.SetNUDConfigurations(1, config)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{})
}
}
@@ -90,8 +91,9 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
- t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ _, err := s.NUDConfigurations(nicID)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{})
}
}
@@ -117,8 +119,9 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
}
config := stack.NUDConfigurations{}
- if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
- t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ err := s.SetNUDConfigurations(nicID, config)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{})
}
}
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 41529ffd5..1c651e216 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -28,108 +28,205 @@ const (
maxPendingPacketsPerResolution = 256
)
+// pendingPacketBuffer is a pending packet buffer.
+//
+// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use
+// WritePackets so we can use a PacketBufferList everywhere.
+type pendingPacketBuffer interface {
+ len() int
+}
+
+func (*PacketBuffer) len() int {
+ return 1
+}
+
+func (p *PacketBufferList) len() int {
+ return p.Len()
+}
+
type pendingPacket struct {
- route *Route
- proto tcpip.NetworkProtocolNumber
- pkt *PacketBuffer
+ routeInfo RouteInfo
+ gso *GSO
+ proto tcpip.NetworkProtocolNumber
+ pkt pendingPacketBuffer
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
- sync.Mutex
+ nic *NIC
- // The packets to send once the resolver completes.
- packets map[<-chan struct{}][]pendingPacket
+ mu struct {
+ sync.Mutex
- // FIFO of channels used to cancel the oldest goroutine waiting for
- // link-address resolution.
- cancelChans []chan struct{}
-}
+ // The packets to send once the resolver completes.
+ //
+ // The link resolution channel is used as the key for this map.
+ packets map[<-chan struct{}][]pendingPacket
-func (f *packetsPendingLinkResolution) init() {
- f.Lock()
- defer f.Unlock()
- f.packets = make(map[<-chan struct{}][]pendingPacket)
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ //
+ // cancelChans holds the same channels that are used as keys to packets.
+ cancelChans []<-chan struct{}
+ }
}
-func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- f.Lock()
- defer f.Unlock()
+func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
+ n := uint64(pkt.len())
+ f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n)
- packets, ok := f.packets[ch]
- if len(packets) == maxPendingPacketsPerResolution {
- p := packets[0]
- packets[0] = pendingPacket{}
- packets = packets[1:]
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
- p.route.Release()
+ if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
+ ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
}
+}
- if l := len(packets); l >= maxPendingPacketsPerResolution {
- panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+func (f *packetsPendingLinkResolution) init(nic *NIC) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.nic = nic
+ f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
+}
+
+// dequeue any pending packets associated with ch.
+//
+// If success is true, packets will be written and sent to the given remote link
+// address.
+func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) {
+ f.mu.Lock()
+ packets, ok := f.mu.packets[ch]
+ delete(f.mu.packets, ch)
+
+ if ok {
+ for i, cancelChan := range f.mu.cancelChans {
+ if cancelChan == ch {
+ f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
+ break
+ }
+ }
}
- f.packets[ch] = append(packets, pendingPacket{
- route: r,
- proto: proto,
- pkt: pkt,
- })
+ f.mu.Unlock()
if ok {
- return
+ f.dequeuePackets(packets, linkAddr, success)
}
+}
- // Wait for the link-address resolution to complete.
- cancel := f.newCancelChannelLocked()
- go func() {
- cancelled := false
- select {
- case <-ch:
- case <-cancel:
- cancelled = true
- }
+// enqueue a packet to be sent once link resolution completes.
+//
+// If the maximum number of pending resolutions is reached, the packets
+// associated with the oldest link resolution will be dequeued as if they failed
+// link resolution.
+func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ f.mu.Lock()
+ // Make sure we attempt resolution while holding f's lock so that we avoid
+ // a race where link resolution completes before we enqueue the packets.
+ //
+ // A @ T1: Call ResolvedFields (get link resolution channel)
+ // B @ T2: Complete link resolution, dequeue pending packets
+ // C @ T1: Enqueue packet that already completed link resolution (which will
+ // never dequeue)
+ //
+ // To make sure B does not interleave with A and C, we make sure A and C are
+ // done while holding the lock.
+ routeInfo, ch, err := r.resolvedFields(nil)
+ switch err.(type) {
+ case nil:
+ // The route resolved immediately, so we don't need to wait for link
+ // resolution to send the packet.
+ f.mu.Unlock()
+ return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt)
+ case *tcpip.ErrWouldBlock:
+ // We need to wait for link resolution to complete.
+ default:
+ f.mu.Unlock()
+ return 0, err
+ }
- f.Lock()
- packets, ok := f.packets[ch]
- delete(f.packets, ch)
- f.Unlock()
+ defer f.mu.Unlock()
- if !ok {
- panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
- }
+ packets, ok := f.mu.packets[ch]
+ packets = append(packets, pendingPacket{
+ routeInfo: routeInfo,
+ gso: gso,
+ proto: proto,
+ pkt: pkt,
+ })
- for _, p := range packets {
- if cancelled || p.route.IsResolutionRequired() {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
+ if len(packets) > maxPendingPacketsPerResolution {
+ f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
+ packets[0] = pendingPacket{}
+ packets = packets[1:]
- if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
- linkResolvableEP.HandleLinkResolutionFailure(pkt)
- }
- } else {
- p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt)
- }
- p.route.Release()
+ if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
}
- }()
+ }
+
+ f.mu.packets[ch] = packets
+
+ if ok {
+ return pkt.len(), nil
+ }
+
+ cancelledPackets := f.newCancelChannelLocked(ch)
+
+ if len(cancelledPackets) != 0 {
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as handing link resolution failures may be a costly operation.
+ go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */)
+ }
+
+ return pkt.len(), nil
}
-// newCancelChannel creates a channel that can cancel a pending forwarding
-// activity. The oldest channel is closed if the number of open channels would
-// exceed maxPendingResolutions.
-func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
- if len(f.cancelChans) == maxPendingResolutions {
- ch := f.cancelChans[0]
- f.cancelChans[0] = nil
- f.cancelChans = f.cancelChans[1:]
- close(ch)
+// newCancelChannelLocked appends the link resolution channel to a FIFO. If the
+// maximum number of pending resolutions is reached, the oldest channel will be
+// removed and its associated pending packets will be returned.
+func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
+ f.mu.cancelChans = append(f.mu.cancelChans, newCH)
+ if len(f.mu.cancelChans) <= maxPendingResolutions {
+ return nil
}
- if l := len(f.cancelChans); l >= maxPendingResolutions {
+
+ ch := f.mu.cancelChans[0]
+ f.mu.cancelChans[0] = nil
+ f.mu.cancelChans = f.mu.cancelChans[1:]
+ if l := len(f.mu.cancelChans); l > maxPendingResolutions {
panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
}
- ch := make(chan struct{})
- f.cancelChans = append(f.cancelChans, ch)
- return ch
+ packets, ok := f.mu.packets[ch]
+ if !ok {
+ panic("must have a packet queue for an uncancelled channel")
+ }
+ delete(f.mu.packets, ch)
+
+ return packets
+}
+
+func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) {
+ for _, p := range packets {
+ if success {
+ p.routeInfo.RemoteLinkAddress = linkAddr
+ _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ } else {
+ f.incrementOutgoingPacketErrors(p.proto, p.pkt)
+
+ if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok {
+ switch pkt := p.pkt.(type) {
+ case *PacketBuffer:
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
+ case *PacketBufferList:
+ for pb := pkt.Front(); pb != nil; pb = pb.Next() {
+ linkResolvableEP.HandleLinkResolutionFailure(pb)
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
+ }
+ }
+ }
+ }
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 68c113b6a..510da8689 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -172,10 +172,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -184,7 +184,7 @@ type TransportProtocol interface {
// ParsePorts returns the source and destination ports stored in a
// packet of this protocol.
- ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+ ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
// protocol that don't match any existing endpoint. For example,
@@ -197,12 +197,12 @@ type TransportProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error
+ SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error
+ Option(option tcpip.GettableTransportProtocolOption) tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -289,10 +289,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- JoinGroup(group tcpip.Address) *tcpip.Error
+ JoinGroup(group tcpip.Address) tcpip.Error
// LeaveGroup attempts to leave the specified group.
- LeaveGroup(group tcpip.Address) *tcpip.Error
+ LeaveGroup(group tcpip.Address) tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -440,17 +440,17 @@ func (k AddressKind) IsPermanent() bool {
type AddressableEndpoint interface {
// AddAndAcquirePermanentAddress adds the passed permanent address.
//
- // Returns tcpip.ErrDuplicateAddress if the address exists.
+ // Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
//
- // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed
+ // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed
// permanent address.
- RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
+ RemovePermanentAddress(addr tcpip.Address) tcpip.Error
// MainAddress returns the endpoint's primary permanent address.
MainAddress() tcpip.AddressWithPrefix
@@ -512,14 +512,14 @@ type NetworkInterface interface {
Promiscuous() bool
// WritePacketToRemote writes the packet to the given remote link address.
- WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePacket writes a packet with the given protocol through the given
// route.
//
// WritePacket takes ownership of the packet buffer. The packet buffer's
// network and transport header must be set.
- WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol through the given
// route. Must not be called with an empty list of packet buffers.
@@ -529,7 +529,7 @@ type NetworkInterface interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// LinkResolvableNetworkEndpoint handles link resolution events.
@@ -547,8 +547,8 @@ type NetworkEndpoint interface {
// Must only be called when the stack is in a state that allows the endpoint
// to send and receive packets.
//
- // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled.
- Enable() *tcpip.Error
+ // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled.
+ Enable() tcpip.Error
// Enabled returns true if the endpoint is enabled.
Enabled() bool
@@ -574,16 +574,16 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It takes ownership of pkt. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length. It takes ownership of pkts and
// underlying packets.
- WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address. It takes ownership of pkt.
- WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
@@ -597,6 +597,26 @@ type NetworkEndpoint interface {
// NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
// this endpoint.
NetworkProtocolNumber() tcpip.NetworkProtocolNumber
+
+ // Stats returns a reference to the network endpoint stats.
+ Stats() NetworkEndpointStats
+}
+
+// NetworkEndpointStats is the interface implemented by each network endpoint
+// stats struct.
+type NetworkEndpointStats interface {
+ // IsNetworkEndpointStats is an empty method to implement the
+ // NetworkEndpointStats marker interface.
+ IsNetworkEndpointStats()
+}
+
+// IPNetworkEndpointStats is a NetworkEndpointStats that tracks IP-related
+// statistics.
+type IPNetworkEndpointStats interface {
+ NetworkEndpointStats
+
+ // IPStats returns the IP statistics of a network endpoint.
+ IPStats() *tcpip.IPStats
}
// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
@@ -634,12 +654,12 @@ type NetworkProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error
+ SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error
+ Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -776,7 +796,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol and route. Must not be
// called with an empty list of packet buffers.
@@ -786,7 +806,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -801,7 +821,7 @@ type InjectableLinkEndpoint interface {
// link.
//
// dest is used by endpoints with multiple raw destinations.
- InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -813,7 +833,7 @@ type LinkAddressResolver interface {
//
// The request is sent from the passed network interface. If the interface
// local address is unspecified, any interface local address may be used.
- LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
@@ -829,12 +849,8 @@ type LinkAddressResolver interface {
// A LinkAddressCache caches link addresses.
type LinkAddressCache interface {
- // CheckLocalAddress determines if the given local address exists, and if it
- // does not exist.
- CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
-
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress)
}
// RawFactory produces endpoints for writing various types of raw packets.
@@ -842,11 +858,11 @@ type RawFactory interface {
// NewUnassociatedEndpoint produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can
// be used to write arbitrary packets that include the network header.
- NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// NewPacketEndpoint produces endpoints for reading and writing packets
// that include network and (when cooked is false) link layer headers.
- NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
}
// GSOType is the type of GSO segments.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1ff7b3a37..4ae0f2a1a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -86,12 +86,21 @@ type RouteInfo struct {
RemoteLinkAddress tcpip.LinkAddress
}
-// Fields returns a RouteInfo with all of r's exported fields. This allows
-// callers to store the route's fields without retaining a reference to it.
+// Fields returns a RouteInfo with all of the known values for the route's
+// fields.
+//
+// If any fields are unknown (e.g. remote link address when it is waiting for
+// link address resolution), they will be unset.
func (r *Route) Fields() RouteInfo {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fieldsLocked()
+}
+
+func (r *Route) fieldsLocked() RouteInfo {
return RouteInfo{
routeInfo: r.routeInfo,
- RemoteLinkAddress: r.RemoteLinkAddress(),
+ RemoteLinkAddress: r.mu.remoteLinkAddress,
}
}
@@ -306,32 +315,43 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary.
+// ResolvedFieldsResult is the result of a route resolution attempt.
+type ResolvedFieldsResult struct {
+ RouteInfo RouteInfo
+ Success bool
+}
+
+// ResolvedFields attempts to resolve the remote link address if it is not
+// known.
//
-// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
-// waiting for ARP reply). If address resolution is required, a notification
-// channel is also returned for the caller to block on. The channel is closed
-// once address resolution is complete (successful or not). If a callback is
-// provided, it will be called when address resolution is complete, regardless
+// If a callback is provided, it will be called before ResolvedFields returns
+// when address resolution is not required. If address resolution is required,
+// the callback will be called once address resolution is complete, regardless
// of success or failure.
-func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
- r.mu.Lock()
-
- if !r.isResolutionRequiredRLocked() {
- // Nothing to do if there is no cache (which does the resolution on cache miss) or
- // link address is already known.
- r.mu.Unlock()
- return nil, nil
- }
-
- // Increment the route's reference count because finishResolution retains a
- // reference to the route and releases it when called.
- r.acquireLocked()
- r.mu.Unlock()
+//
+// Note, the route will not cache the remote link address when address
+// resolution completes.
+func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error {
+ _, _, err := r.resolvedFields(afterResolve)
+ return err
+}
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
+// resolvedFields is like ResolvedFields but also returns a notification channel
+// when address resolution is required. This channel will become readable once
+// address resolution is complete.
+//
+// The route's fields will also be returned, regardless of whether address
+// resolution is required or not.
+func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) {
+ r.mu.RLock()
+ fields := r.fieldsLocked()
+ resolutionRequired := r.isResolutionRequiredRLocked()
+ r.mu.RUnlock()
+ if !resolutionRequired {
+ if afterResolve != nil {
+ afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true})
+ }
+ return fields, nil, nil
}
// If specified, the local address used for link address resolution must be an
@@ -341,18 +361,27 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
- finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
- if ok {
- r.ResolveWith(linkAddress)
- }
+ afterResolveFields := fields
+ linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) {
if afterResolve != nil {
- afterResolve()
+ if r.Success {
+ afterResolveFields.RemoteLinkAddress = r.LinkAddress
+ }
+
+ afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success})
}
- r.Release()
+ })
+ if err == nil {
+ fields.RemoteLinkAddress = linkAddr
}
+ return fields, ch, err
+}
- _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
- return ch, err
+func (r *Route) nextHop() tcpip.Address {
+ if len(r.NextHop) == 0 {
+ return r.RemoteAddress
+ }
+ return r.NextHop
}
// local returns true if the route is a local route.
@@ -371,11 +400,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
- return false
- }
-
- return r.linkRes != nil
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
@@ -404,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
@@ -414,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
-func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
if !r.isValidForOutgoing() {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
@@ -424,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
@@ -496,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
+
+// ConfirmReachable informs the network/link layer that the neighbour used for
+// the route is reachable.
+//
+// "Reachable" is defined as having full-duplex communication between the
+// local and remote ends of the route.
+func (r *Route) ConfirmReachable() {
+ r.outgoingNIC.confirmReachable(r.nextHop())
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c0aec61a6..119c4c505 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -76,12 +76,16 @@ type TCPCubicState struct {
// TCPRACKState is used to hold a copy of the internal RACK state when the
// TCPProbeFunc is invoked.
type TCPRACKState struct {
- XmitTime time.Time
- EndSequence seqnum.Value
- FACK seqnum.Value
- RTT time.Duration
- Reord bool
- DSACKSeen bool
+ XmitTime time.Time
+ EndSequence seqnum.Value
+ FACK seqnum.Value
+ RTT time.Duration
+ Reord bool
+ DSACKSeen bool
+ ReoWnd time.Duration
+ ReoWndIncr uint8
+ ReoWndPersist int8
+ RTTSeq seqnum.Value
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -382,8 +386,6 @@ type Stack struct {
stats tcpip.Stats
- linkAddrCache *linkAddrCache
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
@@ -446,7 +448,7 @@ type Stack struct {
// sendBufferSize holds the min/default/max send buffer sizes for
// endpoints other than TCP.
- sendBufferSize SendBufferSizeOption
+ sendBufferSize tcpip.SendBufferSizeOption
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
@@ -554,7 +556,7 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
@@ -572,11 +574,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
- return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{}
}
case header.IPv6AddressSize:
if len(addr.Addr) == header.IPv4AddressSize {
- return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{}
}
}
@@ -584,10 +586,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
case netProto == t.NetProto:
case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
- return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{}
}
default:
- return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{}
}
return addr, netProto, nil
@@ -636,7 +638,6 @@ func New(opts Options) *Stack {
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
@@ -649,7 +650,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
- sendBufferSize: SendBufferSizeOption{
+ sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -701,10 +702,10 @@ func (s *Stack) UniqueID() uint64 {
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return netProto.SetOption(option)
}
@@ -718,10 +719,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op
// if err != nil {
// ...
// }
-func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return netProto.Option(option)
}
@@ -730,10 +731,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return transProtoState.proto.SetOption(option)
}
@@ -745,10 +746,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
-func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return transProtoState.proto.Option(option)
}
@@ -781,15 +782,15 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables packet forwarding between NICs for the
// passed protocol.
-func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error {
+func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
protocol, ok := s.networkProtocols[protocolNum]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
forwardingProtocol.SetForwarding(enable)
@@ -852,10 +853,10 @@ func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
}
// NewEndpoint creates a new transport layer endpoint of the given protocol.
-func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
t, ok := s.transportProtocols[transport]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return t.proto.NewEndpoint(network, waiterQueue)
@@ -864,9 +865,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// NewRawEndpoint creates a new raw transport layer endpoint of the given
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
-func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
if s.rawFactory == nil {
- return nil, tcpip.ErrNotPermitted
+ return nil, &tcpip.ErrNotPermitted{}
}
if !associated {
@@ -875,7 +876,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
t, ok := s.transportProtocols[transport]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return t.proto.NewRawEndpoint(network, waiterQueue)
@@ -883,9 +884,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// NewPacketEndpoint creates a new packet endpoint listening for the given
// netProto.
-func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if s.rawFactory == nil {
- return nil, tcpip.ErrNotPermitted
+ return nil, &tcpip.ErrNotPermitted{}
}
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
@@ -916,20 +917,20 @@ type NICOptions struct {
// NICs can be configured.
//
// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
-func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
// Make sure id is unique.
if _, ok := s.nics[id]; ok {
- return tcpip.ErrDuplicateNICID
+ return &tcpip.ErrDuplicateNICID{}
}
// Make sure name is unique, unless unnamed.
if opts.Name != "" {
for _, n := range s.nics {
if n.Name() == opts.Name {
- return tcpip.ErrDuplicateNICID
+ return &tcpip.ErrDuplicateNICID{}
}
}
}
@@ -945,7 +946,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
-func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
@@ -963,26 +964,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
-func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.enable()
}
// DisableNIC disables the given NIC.
-func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.disable()
@@ -1003,7 +1004,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
}
// RemoveNIC removes NIC and all related routes from the network stack.
-func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1013,10 +1014,10 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
// removeNICLocked removes NIC and all related routes from the network stack.
//
// s.mu must be locked.
-func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
delete(s.nics, id)
@@ -1050,6 +1051,9 @@ type NICInfo struct {
Stats NICStats
+ // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC.
+ NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats
+
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
@@ -1081,6 +1085,12 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
+
+ netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats)
+ for proto, netEP := range nic.networkEndpoints {
+ netStats[proto] = netEP.Stats()
+ }
+
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
@@ -1088,6 +1098,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
+ NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
@@ -1111,13 +1122,13 @@ type NICStateFlags struct {
}
// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
ap := tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: addr,
@@ -1127,16 +1138,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
netProto, ok := s.networkProtocols[protocol]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
Protocol: protocol,
@@ -1149,13 +1160,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.addAddress(protocolAddress, peb)
@@ -1163,7 +1174,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1171,7 +1182,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
return nic.removeAddress(addr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// AllAddresses returns a map of NICIDs to their protocol addresses (primary
@@ -1189,19 +1200,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
// GetMainNICAddress returns the first non-deprecated primary address and prefix
// for the given NIC and protocol. If no non-deprecated primary address exists,
-// a deprecated primary address and prefix will be returned. Returns an error if
+// a deprecated primary address and prefix will be returned. Returns false if
// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
// address for the given protocol.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ return tcpip.AddressWithPrefix{}, false
}
- return nic.primaryAddress(protocol), nil
+ return nic.primaryAddress(protocol), true
}
func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
@@ -1301,7 +1312,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1337,9 +1348,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return nil, tcpip.ErrBadLocalAddress
+ return nil, &tcpip.ErrBadLocalAddress{}
}
- return nil, tcpip.ErrNetworkUnreachable
+ return nil, &tcpip.ErrNetworkUnreachable{}
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1405,7 +1416,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
}
- return nil, tcpip.ErrNoRoute
+ return nil, &tcpip.ErrNoRoute{}
}
if id == 0 {
@@ -1425,12 +1436,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return nil, tcpip.ErrNoRoute
+ return nil, &tcpip.ErrNoRoute{}
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return nil, tcpip.ErrBadLocalAddress
+ return nil, &tcpip.ErrBadLocalAddress{}
}
- return nil, tcpip.ErrNetworkUnreachable
+ return nil, &tcpip.ErrNetworkUnreachable{}
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1476,13 +1487,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
}
// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
-func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.setPromiscuousMode(enable)
@@ -1492,13 +1503,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error
// SetSpoofing enables or disables address spoofing in the given NIC, allowing
// endpoints to bind to any address in the NIC.
-func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.setSpoofing(enable)
@@ -1506,17 +1517,27 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
return nil
}
-// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.add(fullAddr, linkAddr)
- // TODO: provide a way for a transport endpoint to receive a signal
- // that AddLinkAddress for a particular address has been called.
+// AddLinkAddress adds a link address for the neighbor on the specified NIC.
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr)
+ return nil
+}
+
+// LinkResolutionResult is the result of a link address resolution attempt.
+type LinkResolutionResult struct {
+ LinkAddress tcpip.LinkAddress
+ Success bool
}
-// GetLinkAddress finds the link address corresponding to a neighbor's address.
-//
-// Returns a link address for the remote address, if readily available.
+// GetLinkAddress finds the link address corresponding to a network address.
//
// Returns ErrNotSupported if the stack is not configured with a link address
// resolver for the specified network protocol.
@@ -1525,53 +1546,56 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
-// If onResolve is provided, it will be called either immediately, if
-// resolution is not required, or when address resolution is complete, with
-// the resolved link address and whether resolution succeeded. After any
-// callbacks have been called, the returned notification channel is closed.
+// onResolve will be called either immediately, if resolution is not required,
+// or when address resolution is complete, with the resolved link address and
+// whether resolution succeeded.
//
// If specified, the local address must be an address local to the interface
// the neighbor cache belongs to. The local address is the source address of
// a packet prompting NUD/link address resolution.
-//
-// TODO(gvisor.dev/issue/5151): Don't return the link address.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return "", nil, tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
linkRes, ok := s.linkAddrResolvers[protocol]
if !ok {
- return "", nil, tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
+ }
+
+ if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
+ return nil
}
- return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ return err
}
// Neighbors returns all IP to MAC address associations.
-func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
+func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return nil, tcpip.ErrUnknownNICID
+ return nil, &tcpip.ErrUnknownNICID{}
}
return nic.neighbors()
}
// AddStaticNeighbor statically associates an IP address to a MAC address.
-func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.addStaticNeighbor(addr, linkAddr)
@@ -1580,26 +1604,26 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd
// RemoveNeighbor removes an IP to MAC address association previously created
// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
// is no association with the provided address.
-func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.removeNeighbor(addr)
}
// ClearNeighbors removes all IP to MAC address associations.
-func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.clearNeighbors()
@@ -1609,25 +1633,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
// the stack transport dispatcher.
-func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
s.cleanupEndpointsMu.Unlock()
@@ -1652,13 +1676,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
-func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
-func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
@@ -1762,7 +1786,7 @@ func (s *Stack) Resume() {
// RegisterPacketEndpoint registers ep with the stack, causing it to receive
// all traffic of the specified netProto on the given NIC. If nicID is 0, it
// receives traffic from every NIC.
-func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1781,7 +1805,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network
// Capture on a specific device.
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
return err
@@ -1819,12 +1843,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
// WritePacketToRemote writes a payload on the specified NIC using the provided
// network protocol and remote link address.
-func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
pkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
@@ -1889,37 +1913,37 @@ func (s *Stack) RemoveTCPProbe() {
}
// JoinGroup joins the given multicast group on the given NIC.
-func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.joinGroup(protocol, multicastAddr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// LeaveGroup leaves the given multicast group on the given NIC.
-func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.leaveGroup(protocol, multicastAddr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// IsInGroup returns true if the NIC with ID nicID has joined the multicast
// group multicastAddr.
-func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.isInGroup(multicastAddr), nil
}
- return false, tcpip.ErrUnknownNICID
+ return false, &tcpip.ErrUnknownNICID{}
}
// IPTables returns the stack's iptables.
@@ -1959,26 +1983,26 @@ func (s *Stack) AllowICMPMessage() bool {
// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol
// number installed on the specified NIC.
-func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) {
+func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
nic, ok := s.nics[nicID]
if !ok {
- return nil, tcpip.ErrUnknownNICID
+ return nil, &tcpip.ErrUnknownNICID{}
}
return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
-func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
if !ok {
- return NUDConfigurations{}, tcpip.ErrUnknownNICID
+ return NUDConfigurations{}, &tcpip.ErrUnknownNICID{}
}
return nic.nudConfigs()
@@ -1988,13 +2012,13 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.setNUDConfigs(c)
@@ -2036,7 +2060,7 @@ func generateRandInt64() int64 {
}
// FindNetworkEndpoint returns the network endpoint for the given address.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2048,7 +2072,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
addressEndpoint.DecRef()
return nic.getNetworkEndpoint(netProto), nil
}
- return nil, tcpip.ErrBadAddress
+ return nil, &tcpip.ErrBadAddress{}
}
// FindNICNameFromID returns the name of the NIC for the given NICID.
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 0b093e6c5..8d9b20b7e 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -14,7 +14,9 @@
package stack
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// MinBufferSize is the smallest size of a receive or send buffer.
@@ -29,14 +31,6 @@ const (
DefaultMaxBufferSize = 4 << 20 // 4 MiB
)
-// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
-// get/set the default, min and max send buffer sizes.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max receive buffer sizes.
type ReceiveBufferSizeOption struct {
@@ -46,17 +40,17 @@ type ReceiveBufferSizeOption struct {
}
// SetOption allows setting stack wide options.
-func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+func (s *Stack) SetOption(option interface{}) tcpip.Error {
switch v := option.(type) {
- case SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
@@ -68,11 +62,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error {
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
@@ -81,14 +75,14 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option allows retrieving stack wide options.
-func (s *Stack) Option(option interface{}) *tcpip.Error {
+func (s *Stack) Option(option interface{}) tcpip.Error {
switch v := option.(type) {
- case *SendBufferSizeOption:
+ case *tcpip.SendBufferSizeOption:
s.mu.RLock()
*v = s.sendBufferSize
s.mu.RUnlock()
@@ -101,6 +95,6 @@ func (s *Stack) Option(option interface{}) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7e935ddff..41f95811f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,6 +60,15 @@ const (
protocolNumberOffset = 2
)
+func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
+ if addr, ok := s.GetMainNICAddress(nicID, proto); !ok {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto)
+ } else if addr != want {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want)
+ }
+ return nil
+}
+
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
@@ -81,7 +90,7 @@ type fakeNetworkEndpoint struct {
dispatcher stack.TransportDispatcher
}
-func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+func (f *fakeNetworkEndpoint) Enable() tcpip.Error {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.enabled = true
@@ -145,7 +154,7 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
-func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
@@ -153,7 +162,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -176,18 +185,30 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func (f *fakeNetworkEndpoint) Close() {
f.AddressableEndpointState.Cleanup()
}
+// Stats implements NetworkEndpoint.
+func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats {
+ return &fakeNetworkEndpointStats{}
+}
+
+var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil)
+
+type fakeNetworkEndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {}
+
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
@@ -202,15 +223,15 @@ type fakeNetworkProtocol struct {
}
}
-func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fakeNetNumber
}
-func (f *fakeNetworkProtocol) MinimumPacketSize() int {
+func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
+func (*fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
@@ -232,23 +253,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li
return e
}
-func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
f.defaultTTL = uint8(*v)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(f.defaultTTL)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -397,7 +418,7 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return err
@@ -406,7 +427,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r *stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -435,14 +456,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer
}
}
-func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
@@ -579,8 +600,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr,
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
_, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{})
}
}
@@ -628,8 +649,9 @@ func TestDisableUnknownNIC(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ err := s.DisableNIC(1)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
}
}
@@ -687,8 +709,9 @@ func TestRemoveUnknownNIC(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ err := s.RemoveNIC(1)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
}
}
@@ -731,8 +754,8 @@ func TestRemoveNIC(t *testing.T) {
func TestRouteWithDownNIC(t *testing.T) {
tests := []struct {
name string
- downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
- upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
}{
{
name: "Disabled NIC",
@@ -890,15 +913,15 @@ func TestRouteWithDownNIC(t *testing.T) {
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC2 after being brought down should fail.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
if upFn := test.upFn; upFn != nil {
// Writes with Routes that use NIC1 after being brought up should
@@ -911,7 +934,7 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
}
})
}
@@ -1036,11 +1059,12 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ err := s.RemoveAddress(1, localAddr)
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
}
}
@@ -1087,12 +1111,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ {
+ err := s.RemoveAddress(1, localAddr)
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
+ }
}
}
@@ -1186,7 +1213,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 2. Add Address, everything should work.
@@ -1214,7 +1241,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 4. Add Address back, everything should work again.
@@ -1253,8 +1280,8 @@ func TestEndpointExpiration(t *testing.T) {
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 7. Add Address back, everything should work again.
@@ -1290,7 +1317,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
})
}
@@ -1333,8 +1360,8 @@ func TestPromiscuousMode(t *testing.T) {
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{})
}
// Set promiscuous mode to false, then check that packet can't be
@@ -1540,7 +1567,7 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{})
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
@@ -1590,8 +1617,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ {
+ _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
+ t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
+ }
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
@@ -1610,8 +1640,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ {
+ _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
+ }
}
}
@@ -1753,9 +1786,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
anyAddr = header.IPv6Any
}
- want := tcpip.ErrNetworkUnreachable
+ var want tcpip.Error = &tcpip.ErrNetworkUnreachable{}
if tc.routeNeeded {
- want = tcpip.ErrNoRoute
+ want = &tcpip.ErrNoRoute{}
}
// If there is no endpoint, it won't work.
@@ -1769,8 +1802,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
// Route table is empty but we need a route, this should cause an error.
- if err != tcpip.ErrNoRoute {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{})
}
} else {
if err != nil {
@@ -1861,20 +1894,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
+ gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if len(primaryAddrAdded) == 0 {
// No primary addresses present.
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr)
}
} else {
// At least one primary address was added, verify the returned
// address is in the list of primary addresses we added.
if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded)
}
}
})
@@ -1915,12 +1948,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
- t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
@@ -1928,12 +1957,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get no address after removal.
- gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2102,7 +2127,7 @@ func TestCreateNICWithOptions(t *testing.T) {
type callArgsAndExpect struct {
nicID tcpip.NICID
opts stack.NICOptions
- err *tcpip.Error
+ err tcpip.Error
}
tests := []struct {
@@ -2120,7 +2145,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(1),
opts: stack.NICOptions{Name: "eth2"},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2135,7 +2160,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(2),
opts: stack.NICOptions{Name: "lo"},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2165,7 +2190,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(1),
opts: stack.NICOptions{},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2474,12 +2499,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ // Check that we get no address after removal.
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
- if gotMainAddr != expectedMainAddr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2525,12 +2550,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2561,12 +2582,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// Address should not be considered bound to the
// NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
@@ -2584,12 +2601,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
+ t.Fatal(err)
}
}
@@ -2621,17 +2634,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
t.Fatal("AddAddressWithOptions failed:", err)
}
- addr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ addr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if pi == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
{
@@ -2710,18 +2723,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.RemoveAddress(1, "\x03"); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
- addr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatalf("s.GetMainNICAddress failed: %v", err)
+ addr, ok = s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
-
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else {
if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
}
})
@@ -3247,12 +3259,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should be tentative so it should not be a main address.
- got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC should start DAD for the address.
@@ -3264,12 +3272,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -3284,12 +3288,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC again should be a no-op.
@@ -3299,12 +3299,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -3313,14 +3309,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
testCases := []struct {
name string
rs stack.ReceiveBufferSizeOption
- err *tcpip.Error
+ err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
- {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
{"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
@@ -3352,31 +3348,32 @@ func TestStackSendBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- ss stack.SendBufferSizeOption
- err *tcpip.Error
+ ss tcpip.SendBufferSizeOption
+ err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
- {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
- {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := stack.New(stack.Options{})
defer s.Close()
- if err := s.SetOption(tc.ss); err != tc.err {
- t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ err := s.SetOption(tc.ss)
+ if diff := cmp.Diff(tc.err, err); diff != "" {
+ t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff)
}
- var ss stack.SendBufferSizeOption
if tc.err == nil {
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err != nil {
t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
}
@@ -3778,20 +3775,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Should still get the address when the NIC is diabled.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("DisableNIC(%d): %s", nicID, err)
}
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -3939,7 +3932,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
addrNIC tcpip.NICID
localAddr tcpip.Address
- findRouteErr *tcpip.Error
+ findRouteErr tcpip.Error
dependentOnForwarding bool
}{
{
@@ -3948,7 +3941,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic2Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -3957,7 +3950,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: true,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic2Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -3966,7 +3959,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4002,7 +3995,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID2,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4011,7 +4004,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: true,
addrNIC: nicID2,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4035,7 +4028,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4051,7 +4044,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
addrNIC: nicID1,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4059,7 +4052,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
addrNIC: nicID1,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4067,7 +4060,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4075,7 +4068,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4107,7 +4100,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4115,7 +4108,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4123,7 +4116,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4131,7 +4124,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4186,8 +4179,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if r != nil {
defer r.Release()
}
- if err != test.findRouteErr {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
@@ -4234,8 +4227,11 @@ func TestFindRouteWithForwarding(t *testing.T) {
if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
}
- if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ {
+ err := send(r, data)
+ if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{})
+ }
}
if n := ep1.Drain(); n != 0 {
t.Errorf("got %d unexpected packets from ep1", n)
@@ -4297,8 +4293,9 @@ func TestWritePacketToRemote(t *testing.T) {
}
t.Run("InvalidNICID", func(t *testing.T) {
- if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
- t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView())
+ if _, ok := err.(*tcpip.ErrUnknownDevice); !ok {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{})
}
pkt, ok := e.Read()
if got, want := ok, false; got != want {
@@ -4372,10 +4369,64 @@ func TestGetLinkAddressErrors(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID)
+ {
+ err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{})
+ }
+ }
+ {
+ err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{})
+ }
+ }
+}
+
+func TestStaticGetLinkAddress(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4",
+ proto: ipv4.ProtocolNumber,
+ addr: header.IPv4Broadcast,
+ expectedLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "IPv6",
+ proto: ipv6.ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
}
- if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err)
+ }
+
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 07b2818d2..26eceb804 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -205,7 +205,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
@@ -222,7 +222,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -294,7 +294,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for i, n := range netProtos {
if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
@@ -306,7 +306,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// checkEndpoint checks if an endpoint can be registered with the dispatcher.
-func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for _, n := range netProtos {
if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
return err
@@ -403,7 +403,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
@@ -412,7 +412,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
}
@@ -422,7 +422,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
return nil
}
-func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
ep.mu.RLock()
defer ep.mu.RUnlock()
@@ -431,7 +431,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
}
@@ -456,7 +456,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports
return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
@@ -464,7 +464,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
eps.mu.Lock()
@@ -482,7 +482,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
}
-func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
@@ -490,7 +490,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
eps.mu.RLock()
@@ -649,10 +649,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
-func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
eps.mu.Lock()
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 57e1f8354..10cbbe589 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) {
for _, test := range []struct {
name string
proto tcpip.NetworkProtocolNumber
- want *tcpip.Error
+ want tcpip.Error
}{
- {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
{"success", ipv4.ProtocolNumber, nil},
} {
t.Run(test.name, func(t *testing.T) {
@@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
@@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer wq.EventUnregister(&we)
defer close(ch)
- var err *tcpip.Error
+ var err tcpip.Error
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 9d39533a1..cf5de747b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"io"
"testing"
@@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
-func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
- ep.ops.InitHandler(ep)
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
return ep
}
@@ -86,19 +87,20 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
return tcpip.ReadResult{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
Data: buffer.View(v).ToVectorisedView(),
@@ -112,42 +114,42 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
}
// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return -1, tcpip.ErrUnknownProtocolOption
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*fakeTransportEndpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
r.Release()
return err
@@ -162,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 {
return f.uniqueID
}
-func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
+func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Reset() {
}
-func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
+func (*fakeTransportEndpoint) Listen(int) tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -186,9 +188,8 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
return a, nil, nil
}
-func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
+func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error {
if err := f.proto.stack.RegisterTransportEndpoint(
- a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
@@ -202,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{}, nil
}
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{}, nil
}
@@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress,
route: route,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}
@@ -251,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
func (*fakeTransportEndpoint) Wait() {}
-func (*fakeTransportEndpoint) LastError() *tcpip.Error {
+func (*fakeTransportEndpoint) LastError() tcpip.Error {
return nil
}
@@ -279,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return nil, tcpip.ErrUnknownProtocol
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return nil, &tcpip.ErrUnknownProtocol{}
}
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) {
return 0, 0, nil
}
@@ -299,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp
return stack.UnknownDestinationPacketHandled
}
-func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPModerateReceiveBufferOption:
f.opts.good = bool(*v)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPModerateReceiveBufferOption:
*v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) {
}
// Create buffer that will hold the payload.
- view := buffer.NewView(30)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ b := make([]byte, 30)
+ var r bytes.Reader
+ r.Reset(b)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 56aac093c..c500a0d1c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -29,6 +29,7 @@
package tcpip
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -46,141 +47,6 @@ import (
// Using header.IPv4AddressSize would cause an import cycle.
const ipv4AddressSize = 4
-// Error represents an error in the netstack error space. Using a special type
-// ensures that errors outside of this space are not accidentally introduced.
-//
-// All errors must have unique msg strings.
-//
-// +stateify savable
-type Error struct {
- msg string
-
- ignoreStats bool
-}
-
-// String implements fmt.Stringer.String.
-func (e *Error) String() string {
- if e == nil {
- return "<nil>"
- }
- return e.msg
-}
-
-// IgnoreStats indicates whether this error type should be included in failure
-// counts in tcpip.Stats structs.
-func (e *Error) IgnoreStats() bool {
- return e.ignoreStats
-}
-
-// Errors that can be returned by the network stack.
-var (
- ErrUnknownProtocol = &Error{msg: "unknown protocol"}
- ErrUnknownNICID = &Error{msg: "unknown nic id"}
- ErrUnknownDevice = &Error{msg: "unknown device"}
- ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
- ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
- ErrDuplicateAddress = &Error{msg: "duplicate address"}
- ErrNoRoute = &Error{msg: "no route"}
- ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
- ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
- ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
- ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
- ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
- ErrNoPortAvailable = &Error{msg: "no ports are available"}
- ErrPortInUse = &Error{msg: "port is in use"}
- ErrBadLocalAddress = &Error{msg: "bad local address"}
- ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
- ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
- ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
- ErrConnectionRefused = &Error{msg: "connection was refused"}
- ErrTimeout = &Error{msg: "operation timed out"}
- ErrAborted = &Error{msg: "operation aborted"}
- ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
- ErrDestinationRequired = &Error{msg: "destination address is required"}
- ErrNotSupported = &Error{msg: "operation not supported"}
- ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
- ErrNotConnected = &Error{msg: "endpoint not connected"}
- ErrConnectionReset = &Error{msg: "connection reset by peer"}
- ErrConnectionAborted = &Error{msg: "connection aborted"}
- ErrNoSuchFile = &Error{msg: "no such file"}
- ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
- ErrBadAddress = &Error{msg: "bad address"}
- ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
- ErrMessageTooLong = &Error{msg: "message too long"}
- ErrNoBufferSpace = &Error{msg: "no buffer space available"}
- ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
- ErrNotPermitted = &Error{msg: "operation not permitted"}
- ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
- ErrMalformedHeader = &Error{msg: "header is malformed"}
- ErrBadBuffer = &Error{msg: "bad buffer"}
-)
-
-var messageToError map[string]*Error
-
-var populate sync.Once
-
-// StringToError converts an error message to the error.
-func StringToError(s string) *Error {
- populate.Do(func() {
- var errors = []*Error{
- ErrUnknownProtocol,
- ErrUnknownNICID,
- ErrUnknownDevice,
- ErrUnknownProtocolOption,
- ErrDuplicateNICID,
- ErrDuplicateAddress,
- ErrNoRoute,
- ErrBadLinkEndpoint,
- ErrAlreadyBound,
- ErrInvalidEndpointState,
- ErrAlreadyConnecting,
- ErrAlreadyConnected,
- ErrNoPortAvailable,
- ErrPortInUse,
- ErrBadLocalAddress,
- ErrClosedForSend,
- ErrClosedForReceive,
- ErrWouldBlock,
- ErrConnectionRefused,
- ErrTimeout,
- ErrAborted,
- ErrConnectStarted,
- ErrDestinationRequired,
- ErrNotSupported,
- ErrQueueSizeNotSupported,
- ErrNotConnected,
- ErrConnectionReset,
- ErrConnectionAborted,
- ErrNoSuchFile,
- ErrInvalidOptionValue,
- ErrBadAddress,
- ErrNetworkUnreachable,
- ErrMessageTooLong,
- ErrNoBufferSpace,
- ErrBroadcastDisabled,
- ErrNotPermitted,
- ErrAddressFamilyNotSupported,
- ErrMalformedHeader,
- ErrBadBuffer,
- }
-
- messageToError = make(map[string]*Error)
- for _, e := range errors {
- if messageToError[e.String()] != nil {
- panic("tcpip errors with duplicated message: " + e.String())
- }
- messageToError[e.String()] = e
- }
- })
-
- e, ok := messageToError[s]
- if !ok {
- panic("unknown error message: " + s)
- }
-
- return e
-}
-
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
@@ -194,7 +60,7 @@ type ErrSaveRejection struct {
}
// Error returns a sensible description of the save rejection error.
-func (e ErrSaveRejection) Error() string {
+func (e *ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
@@ -471,30 +337,15 @@ type FullAddress struct {
// This interface allows the endpoint to request the amount of data it needs
// based on internal buffers without exposing them.
type Payloader interface {
- // FullPayload returns all available bytes.
- FullPayload() ([]byte, *Error)
+ io.Reader
- // Payload returns a slice containing at most size bytes.
- Payload(size int) ([]byte, *Error)
+ // Len returns the number of bytes of the unread portion of the
+ // Reader.
+ Len() int
}
-// SlicePayload implements Payloader for slices.
-//
-// This is typically used for tests.
-type SlicePayload []byte
-
-// FullPayload implements Payloader.FullPayload.
-func (s SlicePayload) FullPayload() ([]byte, *Error) {
- return s, nil
-}
-
-// Payload implements Payloader.Payload.
-func (s SlicePayload) Payload(size int) ([]byte, *Error) {
- if size > len(s) {
- size = len(s)
- }
- return s[:size], nil
-}
+var _ Payloader = (*bytes.Buffer)(nil)
+var _ Payloader = (*bytes.Reader)(nil)
var _ io.Writer = (*SliceWriter)(nil)
@@ -647,7 +498,7 @@ type Endpoint interface {
// If non-zero number of bytes are successfully read and written to dst, err
// must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
// should be returned.
- Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error)
+ Read(io.Writer, ReadOptions) (ReadResult, Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -662,7 +513,7 @@ type Endpoint interface {
// stream (TCP) Endpoints may return partial writes, and even then only
// in the case where writing additional data would block. Other Endpoints
// will either write the entire message or return an error.
- Write(Payloader, WriteOptions) (int64, *Error)
+ Write(Payloader, WriteOptions) (int64, Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -676,21 +527,21 @@ type Endpoint interface {
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
//
- // If address.Addr is empty, this means that Enpoint has to be
+ // If address.Addr is empty, this means that Endpoint has to be
// disconnected if this is supported, otherwise
// ErrAddressFamilyNotSupported must be returned.
- Connect(address FullAddress) *Error
+ Connect(address FullAddress) Error
// Disconnect disconnects the endpoint from its peer.
- Disconnect() *Error
+ Disconnect() Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
- Shutdown(flags ShutdownFlags) *Error
+ Shutdown(flags ShutdownFlags) Error
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
- Listen(backlog int) *Error
+ Listen(backlog int) Error
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode. This method does not
@@ -700,36 +551,36 @@ type Endpoint interface {
//
// If peerAddr is not nil then it is populated with the peer address of the
// returned endpoint.
- Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error)
+ Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
- Bind(address FullAddress) *Error
+ Bind(address FullAddress) Error
// GetLocalAddress returns the address to which the endpoint is bound.
- GetLocalAddress() (FullAddress, *Error)
+ GetLocalAddress() (FullAddress, Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
- GetRemoteAddress() (FullAddress, *Error)
+ GetRemoteAddress() (FullAddress, Error)
// Readiness returns the current readiness of the endpoint. For example,
// if waiter.EventIn is set, the endpoint is immediately readable.
Readiness(mask waiter.EventMask) waiter.EventMask
// SetSockOpt sets a socket option.
- SetSockOpt(opt SettableSocketOption) *Error
+ SetSockOpt(opt SettableSocketOption) Error
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOptInt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) Error
// GetSockOpt gets a socket option.
- GetSockOpt(opt GettableSocketOption) *Error
+ GetSockOpt(opt GettableSocketOption) Error
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOptInt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -752,7 +603,7 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
// LastError clears and returns the last error reported by the endpoint.
- LastError() *Error
+ LastError() Error
// SocketOptions returns the structure which contains all the socket
// level options.
@@ -840,10 +691,6 @@ const (
// number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption
- // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
- // specify the send buffer size option.
- SendBufferSizeOption
-
// ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the receive buffer size option.
ReceiveBufferSizeOption
@@ -1011,12 +858,54 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
+// CongestionControlState indicates the current congestion control state for
+// TCP sender.
+type CongestionControlState int
+
+const (
+ // Open indicates that the sender is receiving acks in order and
+ // no loss or dupACK's etc have been detected.
+ Open CongestionControlState = iota
+ // RTORecovery indicates that an RTO has occurred and the sender
+ // has entered an RTO based recovery phase.
+ RTORecovery
+ // FastRecovery indicates that the sender has entered FastRecovery
+ // based on receiving nDupAck's. This state is entered only when
+ // SACK is not in use.
+ FastRecovery
+ // SACKRecovery indicates that the sender has entered SACK based
+ // recovery.
+ SACKRecovery
+ // Disorder indicates the sender either received some SACK blocks
+ // or dupACK's.
+ Disorder
+)
+
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
type TCPInfoOption struct {
- RTT time.Duration
+ // RTT is the smoothed round trip time.
+ RTT time.Duration
+
+ // RTTVar is the round trip time variation.
RTTVar time.Duration
+
+ // RTO is the retransmission timeout for the endpoint.
+ RTO time.Duration
+
+ // CcState is the congestion control state.
+ CcState CongestionControlState
+
+ // SndCwnd is the congestion window, in packets.
+ SndCwnd uint32
+
+ // SndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ SndSsthresh uint32
+
+ // ReorderSeen indicates if reordering is seen in the endpoint.
+ ReorderSeen bool
}
func (*TCPInfoOption) isGettableSocketOption() {}
@@ -1248,6 +1137,31 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ // Min is the minimum size for send buffer.
+ Min int
+
+ // Default is the default size for send buffer.
+ Default int
+
+ // Max is the maximum size for send buffer.
+ Max int
+}
+
+// GetSendBufferLimits is used to get the send buffer size limits.
+type GetSendBufferLimits func(StackHandler) SendBufferSizeOption
+
+// GetStackSendBufferLimits is used to get default, min and max send buffer size.
+func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption {
+ var ss SendBufferSizeOption
+ if err := so.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ return ss
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -1317,8 +1231,33 @@ func (s *StatCounter) String() string {
return strconv.FormatUint(s.Value(), 10)
}
+// A MultiCounterStat keeps track of two counters at once.
+type MultiCounterStat struct {
+ a, b *StatCounter
+}
+
+// Init sets both internal counters to point to a and b.
+func (m *MultiCounterStat) Init(a, b *StatCounter) {
+ m.a = a
+ m.b = b
+}
+
+// Increment adds one to the counters.
+func (m *MultiCounterStat) Increment() {
+ m.a.Increment()
+ m.b.Increment()
+}
+
+// IncrementBy increments the counters by v.
+func (m *MultiCounterStat) IncrementBy(v uint64) {
+ m.a.IncrementBy(v)
+ m.b.IncrementBy(v)
+}
+
// ICMPv4PacketStats enumerates counts for all ICMPv4 packet types.
type ICMPv4PacketStats struct {
+ // LINT.IfChange(ICMPv4PacketStats)
+
// Echo is the total number of ICMPv4 echo packets counted.
Echo *StatCounter
@@ -1358,10 +1297,56 @@ type ICMPv4PacketStats struct {
// InfoReply is the total number of ICMPv4 information reply packets
// counted.
InfoReply *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats)
+}
+
+// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
+type ICMPv4SentPacketStats struct {
+ // LINT.IfChange(ICMPv4SentPacketStats)
+
+ ICMPv4PacketStats
+
+ // Dropped is the total number of ICMPv4 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv4 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats)
+}
+
+// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
+type ICMPv4ReceivedPacketStats struct {
+ // LINT.IfChange(ICMPv4ReceivedPacketStats)
+
+ ICMPv4PacketStats
+
+ // Invalid is the total number of invalid ICMPv4 packets received.
+ Invalid *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats)
+}
+
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
+ // LINT.IfChange(ICMPv4Stats)
+
+ // PacketsSent contains statistics about sent packets.
+ PacketsSent ICMPv4SentPacketStats
+
+ // PacketsReceived contains statistics about received packets.
+ PacketsReceived ICMPv4ReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4Stats)
}
// ICMPv6PacketStats enumerates counts for all ICMPv6 packet types.
type ICMPv6PacketStats struct {
+ // LINT.IfChange(ICMPv6PacketStats)
+
// EchoRequest is the total number of ICMPv6 echo request packets
// counted.
EchoRequest *StatCounter
@@ -1416,32 +1401,14 @@ type ICMPv6PacketStats struct {
// MulticastListenerDone is the total number of Multicast Listener Done
// messages counted.
MulticastListenerDone *StatCounter
-}
-
-// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
-type ICMPv4SentPacketStats struct {
- ICMPv4PacketStats
-
- // Dropped is the total number of ICMPv4 packets dropped due to link
- // layer errors.
- Dropped *StatCounter
- // RateLimited is the total number of ICMPv6 packets dropped due to
- // rate limit being exceeded.
- RateLimited *StatCounter
-}
-
-// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
-type ICMPv4ReceivedPacketStats struct {
- ICMPv4PacketStats
-
- // Invalid is the total number of ICMPv4 packets received that the
- // transport layer could not parse.
- Invalid *StatCounter
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats)
}
// ICMPv6SentPacketStats collects outbound ICMPv6-specific stats.
type ICMPv6SentPacketStats struct {
+ // LINT.IfChange(ICMPv6SentPacketStats)
+
ICMPv6PacketStats
// Dropped is the total number of ICMPv6 packets dropped due to link
@@ -1451,47 +1418,41 @@ type ICMPv6SentPacketStats struct {
// RateLimited is the total number of ICMPv6 packets dropped due to
// rate limit being exceeded.
RateLimited *StatCounter
+
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats)
}
// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
type ICMPv6ReceivedPacketStats struct {
+ // LINT.IfChange(ICMPv6ReceivedPacketStats)
+
ICMPv6PacketStats
// Unrecognized is the total number of ICMPv6 packets received that the
// transport layer does not know how to parse.
Unrecognized *StatCounter
- // Invalid is the total number of ICMPv6 packets received that the
- // transport layer could not parse.
+ // Invalid is the total number of invalid ICMPv6 packets received.
Invalid *StatCounter
// RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets
// dropped due to being router-specific packets.
RouterOnlyPacketsDroppedByHost *StatCounter
-}
-
-// ICMPv4Stats collects ICMPv4-specific stats.
-type ICMPv4Stats struct {
- // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
- // and a single count of packets which failed to write to the link
- // layer.
- PacketsSent ICMPv4SentPacketStats
- // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
- // packet type and a single count of invalid packets received.
- PacketsReceived ICMPv4ReceivedPacketStats
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats)
}
// ICMPv6Stats collects ICMPv6-specific stats.
type ICMPv6Stats struct {
- // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
- // and a single count of packets which failed to write to the link
- // layer.
+ // LINT.IfChange(ICMPv6Stats)
+
+ // PacketsSent contains statistics about sent packets.
PacketsSent ICMPv6SentPacketStats
- // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
- // packet type and a single count of invalid packets received.
+ // PacketsReceived contains statistics about received packets.
PacketsReceived ICMPv6ReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6Stats)
}
// ICMPStats collects ICMP-specific stats (both v4 and v6).
@@ -1505,6 +1466,8 @@ type ICMPStats struct {
// IGMPPacketStats enumerates counts for all IGMP packet types.
type IGMPPacketStats struct {
+ // LINT.IfChange(IGMPPacketStats)
+
// MembershipQuery is the total number of Membership Query messages counted.
MembershipQuery *StatCounter
@@ -1518,22 +1481,29 @@ type IGMPPacketStats struct {
// LeaveGroup is the total number of Leave Group messages counted.
LeaveGroup *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats)
}
// IGMPSentPacketStats collects outbound IGMP-specific stats.
type IGMPSentPacketStats struct {
+ // LINT.IfChange(IGMPSentPacketStats)
+
IGMPPacketStats
// Dropped is the total number of IGMP packets dropped.
Dropped *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats)
}
// IGMPReceivedPacketStats collects inbound IGMP-specific stats.
type IGMPReceivedPacketStats struct {
+ // LINT.IfChange(IGMPReceivedPacketStats)
+
IGMPPacketStats
- // Invalid is the total number of IGMP packets received that IGMP could not
- // parse.
+ // Invalid is the total number of invalid IGMP packets received.
Invalid *StatCounter
// ChecksumErrors is the total number of IGMP packets dropped due to bad
@@ -1543,21 +1513,27 @@ type IGMPReceivedPacketStats struct {
// Unrecognized is the total number of unrecognized messages counted, these
// are silently ignored for forward-compatibilty.
Unrecognized *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats)
}
-// IGMPStats colelcts IGMP-specific stats.
+// IGMPStats collects IGMP-specific stats.
type IGMPStats struct {
- // IGMPSentPacketStats contains counts of sent packets by IGMP packet type
- // and a single count of invalid packets received.
+ // LINT.IfChange(IGMPStats)
+
+ // PacketsSent contains statistics about sent packets.
PacketsSent IGMPSentPacketStats
- // IGMPReceivedPacketStats contains counts of received packets by IGMP packet
- // type and a single count of invalid packets received.
+ // PacketsReceived contains statistics about received packets.
PacketsReceived IGMPReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats)
}
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
+ // LINT.IfChange(IPStats)
+
// PacketsReceived is the total number of IP packets received from the
// link layer.
PacketsReceived *StatCounter
@@ -1575,7 +1551,7 @@ type IPStats struct {
InvalidSourceAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
- // are successfully delivered to the transport layer via HandlePacket.
+ // are successfully delivered to the transport layer.
PacketsDelivered *StatCounter
// PacketsSent is the total number of IP packets sent via WritePacket.
@@ -1613,10 +1589,14 @@ type IPStats struct {
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived *StatCounter
+
+ // LINT.ThenChange(network/ip/stats.go:MultiCounterIPStats)
}
// ARPStats collects ARP-specific stats.
type ARPStats struct {
+ // LINT.IfChange(ARPStats)
+
// PacketsReceived is the number of ARP packets received from the link layer.
PacketsReceived *StatCounter
@@ -1644,10 +1624,6 @@ type ARPStats struct {
// ARP request with a bad local address.
OutgoingRequestBadLocalAddressErrors *StatCounter
- // OutgoingRequestNetworkUnreachableErrors is the number of failures to send
- // an ARP request with a network unreachable error.
- OutgoingRequestNetworkUnreachableErrors *StatCounter
-
// OutgoingRequestsDropped is the number of ARP requests which failed to write
// to a link-layer endpoint.
OutgoingRequestsDropped *StatCounter
@@ -1666,6 +1642,8 @@ type ARPStats struct {
// OutgoingRepliesSent is the number of ARP replies successfully written to a
// link-layer endpoint.
OutgoingRepliesSent *StatCounter
+
+ // LINT.ThenChange(network/arp/stats.go:multiCounterARPStats)
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 1742a178d..71695b630 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = [
"forward_test.go",
+ "iptables_test.go",
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
@@ -16,6 +17,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ac9670f9a..38e1881c7 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -38,96 +38,207 @@ import (
var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil)
var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil)
-// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
-// link endpoint and checks the destination link address before delivering
-// network packets to the network dispatcher.
-//
-// See ethernet.Endpoint for more details.
-func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
- var e endpointWithDestinationCheck
- e.Endpoint.Init(ethernet.New(ep), &e)
- return &e
-}
-
-// endpointWithDestinationCheck is a link endpoint that checks the destination
-// link address before delivering network packets to the network dispatcher.
-type endpointWithDestinationCheck struct {
- nested.Endpoint
-}
-
-// DeliverNetworkPacket implements stack.NetworkDispatcher.
-func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
- e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
- }
-}
-
-func TestForwarding(t *testing.T) {
- const (
- host1NICID = 1
- routerNICID1 = 2
- routerNICID2 = 3
- host2NICID = 4
-
- listenPort = 8080
- )
+const (
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+)
- host1IPv4Addr := tcpip.ProtocolAddress{
+var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 24,
},
}
- routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ routerNIC1IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
- routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ routerNIC2IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
PrefixLen: 8,
},
}
- host2IPv4Addr := tcpip.ProtocolAddress{
+ host2IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
PrefixLen: 8,
},
}
- host1IPv6Addr := tcpip.ProtocolAddress{
+ host1IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::2").To16()),
PrefixLen: 64,
},
}
- routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ routerNIC1IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::1").To16()),
PrefixLen: 64,
},
}
- routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ routerNIC2IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("b::1").To16()),
PrefixLen: 64,
},
}
- host2IPv6Addr := tcpip.ProtocolAddress{
+ host2IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("b::2").To16()),
PrefixLen: 64,
},
}
+)
+
+func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) {
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
+
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ {
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+}
+
+// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
+// link endpoint and checks the destination link address before delivering
+// network packets to the network dispatcher.
+//
+// See ethernet.Endpoint for more details.
+func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
+ var e endpointWithDestinationCheck
+ e.Endpoint.Init(ethernet.New(ep), &e)
+ return &e
+}
+
+// endpointWithDestinationCheck is a link endpoint that checks the destination
+// link address before delivering network packets to the network dispatcher.
+type endpointWithDestinationCheck struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+}
+
+func TestForwarding(t *testing.T) {
+ const listenPort = 8080
type endpointAndAddresses struct {
serverEP tcpip.Endpoint
@@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) {
subTests := []struct {
name string
proto tcpip.TransportProtocolNumber
- expectedConnectErr *tcpip.Error
+ expectedConnectErr tcpip.Error
setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
needRemoteAddr bool
}{
@@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) {
{
name: "TCP",
proto: tcp.ProtocolNumber,
- expectedConnectErr: tcpip.ErrConnectStarted,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
@@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) {
var addr tcpip.FullAddress
for {
newEP, wq, err := ep.Accept(&addr)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-ch
continue
}
@@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) {
host1Stack := stack.New(stackOpts)
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
-
- host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
- routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
-
- if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
- }
- if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
- }
- if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- })
- routerStack.SetRouteTable([]tcpip.Route{
- {
- Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- {
- Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- })
+ setupRoutedStacks(t, host1Stack, routerStack, host2Stack)
epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
defer epsAndAddrs.serverEP.Close()
@@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) {
t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
}
- if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr {
- t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr)
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
}
if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
@@ -436,9 +444,10 @@ func TestForwarding(t *testing.T) {
write := func(ep tcpip.Endpoint, data []byte) {
t.Helper()
- dataPayload := tcpip.SlicePayload(data)
+ var r bytes.Reader
+ r.Reset(data)
var wOpts tcpip.WriteOptions
- n, err := ep.Write(dataPayload, wOpts)
+ n, err := ep.Write(&r, wOpts)
if err != nil {
t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
}
@@ -486,7 +495,7 @@ func TestForwarding(t *testing.T) {
read(serverCH, serverEP, data, clientAddr)
- data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
write(serverEP, data)
read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
})
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
new file mode 100644
index 000000000..21a8dd291
--- /dev/null
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -0,0 +1,336 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type inputIfNameMatcher struct {
+ name string
+}
+
+var _ stack.Matcher = (*inputIfNameMatcher)(nil)
+
+func (*inputIfNameMatcher) Name() string {
+ return "inputIfNameMatcher"
+}
+
+func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
+ return (hook == stack.Input && im.name != "" && im.name == inNicName), false
+}
+
+const (
+ nicID = 1
+ nicName = "nic1"
+ anotherNicName = "nic2"
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ srcAddrV4 = "\x0a\x00\x00\x01"
+ dstAddrV4 = "\x0a\x00\x00\x02"
+ srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ payloadSize = 20
+)
+
+func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ }
+ return s, e
+}
+
+func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ }
+ return s, e
+}
+
+func genPacketV6() *stack.PacketBuffer {
+ pktSize := header.IPv6MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv6(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadSize,
+ TransportProtocol: 99,
+ HopLimit: 255,
+ SrcAddr: srcAddrV6,
+ DstAddr: dstAddrV6,
+ })
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func genPacketV4() *stack.PacketBuffer {
+ pktSize := header.IPv4MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv4Fields{
+ TOS: 0,
+ TotalLength: uint16(pktSize),
+ ID: 1,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: 48,
+ Protocol: 99,
+ SrcAddr: srcAddrV4,
+ DstAddr: dstAddrV4,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func TestIPTablesStatsForInput(t *testing.T) {
+ tests := []struct {
+ name string
+ setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func() *stack.PacketBuffer
+ proto tcpip.NetworkProtocolNumber
+ expectReceived int
+ expectInputDropped int
+ }{
+ {
+ name: "IPv6 Accept",
+ setupStack: genStackV6,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept",
+ setupStack: genStackV4,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface matches)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface matches)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := test.setupStack(t)
+ test.setupFilter(t, s)
+ e.InjectInbound(test.proto, test.genPacket())
+
+ if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
+ t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
+ }
+ if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
+ t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index af32d3009..7069352f2 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -19,11 +19,14 @@ import (
"fmt"
"net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
@@ -32,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -207,8 +211,10 @@ func TestPing(t *testing.T) {
defer ep.Close()
icmpBuf := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(icmpBuf)
wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
} else if want := int64(len(icmpBuf)); n != want {
t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
@@ -251,7 +257,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name string
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
- expectedWriteErr *tcpip.Error
+ expectedWriteErr tcpip.Error
sockError tcpip.SockError
}{
{
@@ -270,9 +276,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name: "IPv4 without resolvable remote",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
sockError: tcpip.SockError{
- Err: tcpip.ErrNoRoute,
+ Err: &tcpip.ErrNoRoute{},
ErrType: byte(header.ICMPv4DstUnreachable),
ErrCode: byte(header.ICMPv4HostUnreachable),
ErrOrigin: tcpip.SockExtErrorOriginICMP,
@@ -292,9 +298,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name: "IPv6 without resolvable remote",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
sockError: tcpip.SockError{
- Err: tcpip.ErrNoRoute,
+ Err: &tcpip.ErrNoRoute{},
ErrType: byte(header.ICMPv6DstUnreachable),
ErrCode: byte(header.ICMPv6AddressUnreachable),
ErrOrigin: tcpip.SockExtErrorOriginICMP6,
@@ -351,16 +357,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
remoteAddr := listenerAddr
remoteAddr.Addr = test.remoteAddr
- if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted)
+ {
+ err := clientEP.Connect(remoteAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{})
+ }
}
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
<-ch
- var wOpts tcpip.WriteOptions
- if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr {
- t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ _, err := clientEP.Write(&r, wOpts)
+ if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" {
+ t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff)
+ }
}
if test.expectedWriteErr == nil {
@@ -374,7 +388,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
sockErrCmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(tcpip.SockError{}),
- cmp.Comparer(func(a, b *tcpip.Error) bool {
+ cmp.Comparer(func(a, b tcpip.Error) bool {
// tcpip.Error holds an unexported field but the errors netstack uses
// are pre defined so we can simply compare pointers.
return a == b
@@ -404,20 +418,134 @@ func TestGetLinkAddress(t *testing.T) {
)
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedLinkAddr bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedOk bool
}{
{
- name: "IPv4",
+ name: "IPv4 resolvable",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
},
{
- name: "IPv6",
+ name: "IPv6 resolvable",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, useNeighborCache := range []bool{true, false} {
+ t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ UseNeighborCache: useNeighborCache,
+ }
+
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestRouteResolvedFields(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ localAddr tcpip.Address
+ remoteAddr tcpip.Address
+ immediatelyResolvable bool
+ expectedSuccess bool
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 immediately resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv4AllSystems,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
+ },
+ {
+ name: "IPv6 immediately resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv6AllNodesMulticastAddress,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
+ {
+ name: "IPv4 resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv6 resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
},
}
@@ -431,28 +559,618 @@ func TestGetLinkAddress(t *testing.T) {
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ var wantRouteInfo stack.RouteInfo
+ wantRouteInfo.LocalLinkAddress = linkAddr1
+ wantRouteInfo.LocalAddress = test.localAddr
+ wantRouteInfo.RemoteAddress = test.remoteAddr
+ wantRouteInfo.NetProto = test.netProto
+ wantRouteInfo.Loop = stack.PacketOut
+ wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
+
+ ch := make(chan stack.ResolvedFieldsResult, 1)
+
+ if !test.immediatelyResolvable {
+ wantUnresolvedRouteInfo := wantRouteInfo
+ wantUnresolvedRouteInfo.RemoteLinkAddress = ""
- for i := 0; i < 2; i++ {
- addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {})
- var want *tcpip.Error
- if i == 0 {
- want = tcpip.ErrWouldBlock
+ err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
- if err != want {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want)
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
}
- if i == 0 {
- <-ch
- continue
+ if !test.expectedSuccess {
+ return
}
- if addr != linkAddr2 {
- t.Fatalf("got addr = %s, want = %s", addr, linkAddr2)
+ // At this point the neighbor table should be populated so the route
+ // should be immediately resolvable.
+ }
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != nil {
+ t.Errorf("r.ResolvedFields(_): %s", err)
+ }
+ select {
+ case routeResolveRes := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected route to be immediately resolvable")
}
})
}
})
}
}
+
+func TestWritePacketsLinkResolution(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedWriteErr tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
+ }
+ defer serverEP.Close()
+
+ serverAddr := tcpip.FullAddress{Port: 1234}
+ if err := serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ data := []byte{1, 2}
+ var pkts stack.PacketBufferList
+ for _, d := range data {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: buffer.View([]byte{d}).ToVectorisedView(),
+ })
+ pkt.TransportProtocolNumber = udp.ProtocolNumber
+ length := uint16(pkt.Size())
+ udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: serverAddr.Port,
+ Length: length,
+ })
+ xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+
+ pkts.PushBack(pkt)
+ }
+
+ params := stack.NetworkHeaderParams{
+ Protocol: udp.ProtocolNumber,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+
+ if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
+ t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
+ } else if want := pkts.Len(); want != n {
+ t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
+ }
+
+ var writer bytes.Buffer
+ count := 0
+ for {
+ var rOpts tcpip.ReadOptions
+ res, err := serverEP.Read(&writer, rOpts)
+ if err != nil {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ // Should not have anymore bytes to read after we read the sent
+ // number of bytes.
+ if count == len(data) {
+ break
+ }
+
+ <-serverCH
+ continue
+ }
+
+ t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
+ }
+ count += res.Count
+ }
+
+ if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
+ t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
+ t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+type eventType int
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
+)
+
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ entry stack.NeighborEntry
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
+}
+
+var _ stack.NUDDispatcher = (*nudDispatcher)(nil)
+
+type nudDispatcher struct {
+ c chan eventInfo
+}
+
+func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) waitForEvent(want eventInfo) error {
+ if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ return nil
+}
+
+// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
+// that the neighbor used for a route is reachable.
+func TestTCPConfirmNeighborReachability(t *testing.T) {
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ neighborAddr tcpip.Address
+ getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{})
+ isHost1Listener bool
+ }{
+ {
+ name: "IPv4 active connection through neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv6 active connection through neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv4 active connection to neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv6 active connection to neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv4 passive connection to neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv6 passive connection to neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv4 passive connection through neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv6 passive connection through neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ nudDisp := nudDispatcher{
+ c: make(chan eventInfo, 3),
+ }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: clock,
+ UseNeighborCache: true,
+ }
+ host1StackOpts := stackOpts
+ host1StackOpts.NUDDisp = &nudDisp
+
+ host1Stack := stack.New(host1StackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ setupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ // Add a reachable dynamic entry to our neighbor table for the remote.
+ {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryAdded,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
+ }); err != nil {
+ t.Fatalf("error waiting for initial NUD event: %s", err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+
+ // Wait for the remote's neighbor entry to be stale before creating a
+ // TCP connection from host1 to some remote.
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err)
+ }
+ // The maximum reachable time for a neighbor is some maximum random factor
+ // applied to the base reachable time.
+ //
+ // See NUDConfigurations.BaseReachableTime for more information.
+ maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
+ clock.Advance(maxReachableTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for stale NUD event: %s", err)
+ }
+
+ listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
+ defer listenerEP.Close()
+ defer clientEP.Close()
+ listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
+ if err := listenerEP.Bind(listenerAddr); err != nil {
+ t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
+ }
+ if err := listenerEP.Listen(1); err != nil {
+ t.Fatalf("listenerEP.Listen(1): %s", err)
+ }
+ {
+ err := clientEP.Connect(listenerAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{})
+ }
+ }
+
+ // Wait for the TCP handshake to complete then make sure the neighbor is
+ // reachable without entering the probe state as TCP should provide NUD
+ // with confirmation that the neighbor is reachable (indicated by a
+ // successful 3-way handshake).
+ <-clientCH
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for delay NUD event: %s", err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+
+ // Wait for the neighbor to be stale again then send data to the remote.
+ //
+ // On successful transmission, the neighbor should become reachable
+ // without probing the neighbor as a TCP ACK would be received which is an
+ // indication of the neighbor being reachable.
+ clock.Advance(maxReachableTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for stale NUD event: %s", err)
+ }
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := clientEP.Write(&r, wOpts); err != nil {
+ t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for delay NUD event: %s", err)
+ }
+ if test.isHost1Listener {
+ // If host1 is not the client, host1 does not send any data so TCP
+ // has no way to know it is making forward progress. Because of this,
+ // TCP should not mark the route reachable and NUD should go through the
+ // probe state.
+ clock.Advance(nudConfigs.DelayFirstProbeTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for probe NUD event: %s", err)
+ }
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 3b13ba04d..ab67762ef 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
type ndpDispatcher struct{}
-func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
}
func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
@@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
Port: localPort,
},
}
- n, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ var r bytes.Reader
+ r.Reset(data)
+ n, err := sep.Write(&r, wopts)
if err != nil {
t.Fatalf("sep.Write(_, _): %s", err)
}
@@ -260,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
- } else if err != tcpip.ErrWouldBlock {
- t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
+ } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
}
})
}
@@ -320,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
}
- if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- })); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ {
+ err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ }))
+ if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
+ }
}
}
@@ -468,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
Addr: test.dstAddr,
Port: localPort,
}
- if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ {
+ err := connectingEndpoint.Connect(connectAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ }
}
if !test.expectAccept {
- if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ _, _, err := listeningEndpoint.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
return
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index ce7c16bd1..d685fdd36 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
- } else if err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
+ } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
}
})
}
@@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
- if n, err := wep.ep.Write(data, writeOpts); err != nil {
+ data := []byte{byte(i), 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := wep.ep.Write(&r, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
@@ -759,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
if err := ep.SetSockOpt(&removeOpt); err != nil {
t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
}
- if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
+ {
+ _, err := ep.Read(&buf, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
+ }
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index b222d2b05..9654c9527 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) {
linkEndpoint func() stack.LinkEndpoint
localAddr tcpip.Address
icmpBuf func(*testing.T) buffer.View
- expectedConnectErr *tcpip.Error
+ expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
}{
{
@@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv4.ProtocolNumber,
linkEndpoint: loopback.New,
icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
{
@@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
{
@@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: channelEPCheck,
},
{
@@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: channelEPCheck,
},
}
@@ -186,17 +186,22 @@ func TestLocalPing(t *testing.T) {
defer ep.Close()
connAddr := tcpip.FullAddress{Addr: test.localAddr}
- if err := ep.Connect(connAddr); err != test.expectedConnectErr {
- t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ {
+ err := ep.Connect(connAddr)
+ if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff)
+ }
}
if test.expectedConnectErr != nil {
return
}
- payload := tcpip.SlicePayload(test.icmpBuf(t))
+ payload := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(payload)
var wOpts tcpip.WriteOptions
- if n, err := ep.Write(payload, wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
} else if n != int64(len(payload)) {
t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
@@ -261,12 +266,12 @@ func TestLocalUDP(t *testing.T) {
subTests := []struct {
name string
addAddress bool
- expectedWriteErr *tcpip.Error
+ expectedWriteErr tcpip.Error
}{
{
name: "Unassigned local address",
addAddress: false,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
},
{
name: "Assigned local address",
@@ -329,12 +334,14 @@ func TestLocalUDP(t *testing.T) {
Port: 80,
}
- clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ clientPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(clientPayload)
wOpts := tcpip.WriteOptions{
To: &serverAddr,
}
- if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
} else if subTest.expectedWriteErr != nil {
// Nothing else to test if we expected not to be able to send the
@@ -376,12 +383,14 @@ func TestLocalUDP(t *testing.T) {
}
}
- serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ serverPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(serverPayload)
wOpts := tcpip.WriteOptions{
To: &clientAddr,
}
- if n, err := server.Write(serverPayload, wOpts); err != nil {
+ if n, err := server.Write(&r, wOpts); err != nil {
t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
} else if n != int64(len(serverPayload)) {
t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 256e19296..3cf05520d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -69,8 +69,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
+ mu sync.RWMutex `state:"nosave"`
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
@@ -85,7 +84,7 @@ type endpoint struct {
ops tcpip.SocketOptions
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.SetSendBufferSize(32*1024, false /* notify */)
+
+ // Override with stack defaults.
+ var ss tcpip.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
+ }
return ep, nil
}
@@ -119,7 +124,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -154,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -188,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
@@ -199,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
case stateConnected:
@@ -207,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
case stateBound:
if to == nil {
- return false, tcpip.ErrDestinationRequired
+ return false, &tcpip.ErrDestinationRequired{}
}
return false, nil
default:
- return false, tcpip.ErrInvalidEndpointState
+ return false, &tcpip.ErrInvalidEndpointState{}
}
e.mu.RUnlock()
@@ -236,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -257,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
to := opts.To
@@ -270,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
@@ -292,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
nicID := to.NIC
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
@@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
route = r
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
+ var err tcpip.Error
switch e.NetProto {
case header.IPv4ProtocolNumber:
err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
@@ -334,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.TTLOption:
e.mu.Lock()
@@ -351,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSize
- e.mu.Unlock()
- return v, nil
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
@@ -381,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -410,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
// Linux performs these basic checks.
if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
icmpv4.SetChecksum(0)
@@ -424,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -441,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
dataVV := data.ToVectorisedView()
@@ -456,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -465,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -485,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
if nicID != 0 && nicID != e.BindNICID {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nicID = e.BindNICID
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -535,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
e.shutdownFlags |= flags
if e.state != stateConnected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
if flags&tcpip.ShutdownRead != 0 {
@@ -565,31 +566,31 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen is not supported by UDP, it just fails.
-func (*endpoint) Listen(int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
- switch err {
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
+ switch err.(type) {
case nil:
return true, nil
- case tcpip.ErrPortInUse:
+ case *tcpip.ErrPortInUse:
return false, nil
default:
return false, err
@@ -599,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
return id, err
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.state != stateInitial {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -619,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
if len(addr.Addr) != 0 {
// A local address was specified, verify that it's valid.
if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
@@ -647,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -663,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -675,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return tcpip.FullAddress{
@@ -805,7 +806,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
func (*endpoint) Wait() {}
// LastError implements tcpip.Endpoint.LastError.
-func (*endpoint) LastError() *tcpip.Error {
+func (*endpoint) LastError() tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 9d263c0ec..c9fa9974a 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -69,12 +69,13 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
if e.state != stateBound && e.state != stateConnected {
return
}
- var err *tcpip.Error
+ var err tcpip.Error
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
if err != nil {
@@ -84,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.ID.LocalAddress = e.route.LocalAddress
} else if len(e.ID.LocalAddress) != 0 { // stateBound
if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3820e5dc7..47f7dd1cb 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if netProto != p.netProto() {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if netProto != p.netProto() {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
@@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int {
}
// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil.
-func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
switch p.number {
case ProtocolNumber4:
hdr := header.ICMPv4(v)
@@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c0d6fb442..73bb66830 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -79,24 +79,22 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
- boundNIC tcpip.NICID
+ mu sync.RWMutex `state:"nosave"`
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
// lastErrorMu protects lastError.
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
// NewEndpoint returns a new packet endpoint.
-func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
netProto: netProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- ep.sndBufSizeMax = ss.Default
+ ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -162,16 +159,16 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
if ep.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if ep.rcvClosed {
ep.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
ep.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -201,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
n, err := packet.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
}
-func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
-// connected, and this function always returnes tcpip.ErrNotSupported.
-func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- return tcpip.ErrNotSupported
+// connected, and this function always returnes *tcpip.ErrNotSupported.
+func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
-// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- return tcpip.ErrNotSupported
+// with Shutdown, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
-// Listen, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Listen(backlog int) *tcpip.Error {
- return tcpip.ErrNotSupported
+// Listen, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Listen(backlog int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
-// Accept, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+// Accept, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
// Bind implements tcpip.Endpoint.Bind.
-func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
// TODO(gvisor.dev/issue/173): Add Bind support.
// "By default, all packets of the specified protocol type are passed
@@ -277,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
// Even a connected socket doesn't return a remote address.
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness implements tcpip.Endpoint.Readiness.
@@ -306,38 +303,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
// used with SetSockOpt, and this function always returns
-// tcpip.ErrNotSupported.
-func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+// *tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := ep.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
- }
- if v > ss.Max {
- v = ss.Max
- }
- if v < ss.Min {
- v = ss.Min
- }
- ep.mu.Lock()
- ep.sndBufSizeMax = v
- ep.mu.Unlock()
- return nil
-
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -357,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (ep *endpoint) LastError() *tcpip.Error {
+func (ep *endpoint) LastError() tcpip.Error {
ep.lastErrorMu.Lock()
defer ep.lastErrorMu.Unlock()
@@ -371,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+func (ep *endpoint) UpdateLastError(err tcpip.Error) {
ep.lastErrorMu.Lock()
ep.lastError = err
ep.lastErrorMu.Unlock()
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
ep.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- v := ep.sndBufSizeMax
- ep.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
ep.rcvMu.Lock()
v := ep.rcvBufSizeMax
@@ -408,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e2fa96d17..ece662c0d 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -63,29 +63,11 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- // StackFromEnv is a stack used specifically for save/restore.
ep.stack = stack.StackFromEnv
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
// TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(*err)
+ panic(err)
}
}
-
-// saveLastError is invoked by stateify.
-func (ep *endpoint) saveLastError() string {
- if ep.lastError == nil {
- return ""
- }
-
- return ep.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (ep *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- ep.lastError = tcpip.StringToError(s)
-}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ae743f75e..9c9ccc0ff 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -76,12 +76,10 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route *stack.Route `state:"manual"`
@@ -95,13 +93,13 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
e := &endpoint{
@@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
associated: associated,
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
e.ops.SetHeaderIncluded(!associated)
+ e.ops.SetSendBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- e.sndBufSizeMax = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
return nil, err
}
@@ -159,7 +157,7 @@ func (e *endpoint) Close() {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -191,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -227,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
}
// Write implements tcpip.Endpoint.Write.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// We can create, but not write to, unassociated IPv6 endpoints.
if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
}
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -267,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
e.mu.RLock()
defer e.mu.RUnlock()
if e.closed {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
- payloadBytes, err := p.FullPayload()
- if err != nil {
- return 0, err
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
// If this is an unassociated socket and callee provided a nonzero
@@ -290,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
dstAddr := ip.DestinationAddress()
// Update dstAddr with the address in the IP header, unless
@@ -311,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- return 0, tcpip.ErrDestinationRequired
+ return 0, &tcpip.ErrDestinationRequired{}
}
return e.finishWrite(payloadBytes, e.route)
@@ -321,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
// Find the route to the destination. If BindAddress is 0,
@@ -338,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) {
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
@@ -365,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect implements tcpip.Endpoint.Connect.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
- return tcpip.ErrAddressFamilyNotSupported
+ return &tcpip.ErrAddressFamilyNotSupported{}
}
e.mu.Lock()
defer e.mu.Unlock()
if e.closed {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nic := addr.NIC
@@ -395,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
} else if addr.NIC != e.BindNICID {
// We're bound and addr specifies a NIC. They must be
// the same.
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
}
@@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
route.Release()
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.RegisterNICID = nic
}
- // Save the route we've connected via.
+ if e.route != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.route.Release()
+ }
e.route = route
e.connected = true
@@ -423,42 +424,42 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
if !e.connected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
return nil
}
// Listen implements tcpip.Endpoint.Listen.
-func (*endpoint) Listen(backlog int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(backlog int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept implements tcpip.Endpoint.Accept.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
// Bind implements tcpip.Endpoint.Bind.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
- return tcpip.ErrBadLocalAddress
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
+ return &tcpip.ErrBadLocalAddress{}
}
if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.RegisterNICID = addr.NIC
e.BindNICID = addr.NIC
}
@@ -470,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
// Even a connected socket doesn't return a remote address.
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness implements tcpip.Endpoint.Readiness.
@@ -498,37 +499,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := e.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
- }
- if v > ss.Max {
- v = ss.Max
- }
- if v < ss.Min {
- v = ss.Min
- }
- e.mu.Lock()
- e.sndBufSizeMax = v
- e.mu.Unlock()
- return nil
-
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -548,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSizeMax
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
v := e.rcvBufSizeMax
@@ -583,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -703,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
func (*endpoint) Wait() {}
// LastError implements tcpip.Endpoint.LastError.
-func (*endpoint) LastError() *tcpip.Error {
+func (*endpoint) LastError() tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 4a7e1c039..263ec5146 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -69,10 +69,11 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
// If the endpoint is connected, re-connect.
if e.connected {
- var err *tcpip.Error
+ var err tcpip.Error
// TODO(gvisor.dev/issue/4906): Properly restore the route with the right
// remote address. We used to pass e.remote.RemoteAddress which was
// effectively the empty address but since moving e.route to hold a pointer
@@ -88,12 +89,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is bound, re-bind.
if e.bound {
if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index f30aa2a4a..e393b993d 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -25,11 +25,11 @@ import (
type EndpointFactory struct{}
// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
-func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
-func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 7e81203ba..fcdd032c5 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -99,7 +99,6 @@ go_test(
"//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6921de0f1..842c1622b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
ep.mu.Unlock()
ep.Close()
- return nil, tcpip.ErrConnectionAborted
+ return nil, &tcpip.ErrConnectionAborted{}
}
l.addPendingEndpoint(ep)
@@ -281,14 +281,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
l.removePendingEndpoint(ep)
- return nil, tcpip.ErrConnectionAborted
+ return nil, &tcpip.ErrConnectionAborted{}
}
deferAccept = l.listenEP.deferAccept
}
// Register new endpoint so that packets are routed to it.
- if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
@@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
defer s.decRef()
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
@@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// and needs to handle it.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
}
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
n.mu.Unlock()
n.Close()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index f711cd4df..4695b66d6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool {
// synSentState handles a segment received when the TCP 3-way handshake is in
// the SYN-SENT state.
-func (h *handshake) synSentState(s *segment) *tcpip.Error {
+func (h *handshake) synSentState(s *segment) tcpip.Error {
// RFC 793, page 37, states that in the SYN-SENT state, a reset is
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(header.TCPFlagRst) {
@@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.ep.workerCleanup = true
// Although the RFC above calls out ECONNRESET, Linux actually returns
// ECONNREFUSED here so we do as well.
- return tcpip.ErrConnectionRefused
+ return &tcpip.ErrConnectionRefused{}
}
return nil
}
@@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// synRcvdState handles a segment received when the TCP 3-way handshake is in
// the SYN-RCVD state.
-func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if s.flagIsSet(header.TCPFlagRst) {
// RFC 793, page 37, states that in the SYN-RCVD state, a reset
// is acceptable if the sequence number is in the window.
if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
- return tcpip.ErrConnectionRefused
+ return &tcpip.ErrConnectionRefused{}
}
return nil
}
@@ -349,7 +349,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
if !h.active {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
h.resetState()
@@ -412,7 +412,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
-func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+func (h *handshake) handleSegment(s *segment) tcpip.Error {
h.sndWnd = s.window
if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
h.sndWnd <<= uint8(h.sndWndScale)
@@ -429,7 +429,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error {
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
-func (h *handshake) processSegments() *tcpip.Error {
+func (h *handshake) processSegments() tcpip.Error {
for i := 0; i < maxSegmentsPerWake; i++ {
s := h.ep.segmentQueue.dequeue()
if s == nil {
@@ -505,7 +505,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
-func (h *handshake) complete() *tcpip.Error {
+func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
resendWaker := sleep.Waker{}
@@ -555,7 +555,7 @@ func (h *handshake) complete() *tcpip.Error {
case wakerForNotification:
n := h.ep.fetchNotifications()
if (n&notifyClose)|(n&notifyAbort) != 0 {
- return tcpip.ErrAborted
+ return &tcpip.ErrAborted{}
}
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
@@ -593,19 +593,19 @@ type backoffTimer struct {
t *time.Timer
}
-func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
if timeout > maxTimeout {
- return nil, tcpip.ErrTimeout
+ return nil, &tcpip.ErrTimeout{}
}
bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
bt.t = time.AfterFunc(timeout, f)
return bt, nil
}
-func (bt *backoffTimer) reset() *tcpip.Error {
+func (bt *backoffTimer) reset() tcpip.Error {
bt.timeout *= 2
if bt.timeout > MaxRTO {
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
bt.t.Reset(bt.timeout)
return nil
@@ -706,7 +706,7 @@ type tcpFields struct {
txHash uint32
}
-func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error {
tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
@@ -716,7 +716,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error {
tf.txHash = e.txHash
if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
@@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
}
}
-func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
// We need to shallow clone the VectorisedView here as ReadToView will
// split the VectorisedView and Trim underlying views as it splits. Not
// doing the clone here will cause the underlying views of data itself
@@ -803,7 +803,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
optLen := len(tf.opts)
if tf.rcvWnd > math.MaxUint16 {
tf.rcvWnd = math.MaxUint16
@@ -875,7 +875,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
}
// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
var sackBlocks []header.SACKBlock
if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
@@ -895,55 +895,60 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
return err
}
-func (e *endpoint) handleWrite() *tcpip.Error {
- // Move packets from send queue to send list. The queue is accessible
- // from other goroutines and protected by the send mutex, while the send
- // list is only accessible from the handler goroutine, so it needs no
- // mutexes.
+func (e *endpoint) handleWrite() {
e.sndBufMu.Lock()
+ next := e.drainSendQueueLocked()
+ e.sndBufMu.Unlock()
+
+ e.sendData(next)
+}
+// Move packets from send queue to send list.
+//
+// Precondition: e.sndBufMu must be locked.
+func (e *endpoint) drainSendQueueLocked() *segment {
first := e.sndQueue.Front()
if first != nil {
e.snd.writeList.PushBackList(&e.sndQueue)
e.sndBufInQueue = 0
}
+ return first
+}
- e.sndBufMu.Unlock()
-
+// Precondition: e.mu must be locked.
+func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
- e.snd.writeNext = first
+ e.snd.writeNext = next
}
// Push out any new packets.
e.snd.sendData()
-
- return nil
}
-func (e *endpoint) handleClose() *tcpip.Error {
+func (e *endpoint) handleClose() {
if !e.EndpointState().connected() {
- return nil
+ return
}
// Drain the send queue.
e.handleWrite()
// Mark send side as closed.
e.snd.closed = true
-
- return nil
}
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
// method must only be called from the protocol goroutine.
-func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+func (e *endpoint) resetConnectionLocked(err tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
e.hardError = err
- if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ switch err.(type) {
+ case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout:
+ default:
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
// which can cause sndNxt to be outside the acceptable window on the
@@ -1053,7 +1058,7 @@ func (e *endpoint) drainClosingSegmentQueue() {
}
}
-func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
@@ -1081,7 +1086,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.hardError = tcpip.ErrAborted
+ e.hardError = &tcpip.ErrAborted{}
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1094,14 +1099,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// handleSegment is invoked from the processor goroutine
// rather than the worker goroutine.
e.notifyProtocolGoroutine(notifyResetByPeer)
- return false, tcpip.ErrConnectionReset
+ return false, &tcpip.ErrConnectionReset{}
}
}
return true, nil
}
// handleSegments processes all inbound segments.
-func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
+func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
if e.EndpointState().closed() {
@@ -1148,7 +1153,7 @@ func (e *endpoint) probeSegment() {
// handleSegment handles a given segment and notifies the worker goroutine if
// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
// Invoke the tcp probe if installed. The tcp probe function will update
// the TCPEndpointState after the segment is processed.
defer e.probeSegment()
@@ -1222,7 +1227,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
-func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
userTimeout := e.userTimeout
e.keepalive.Lock()
@@ -1236,13 +1241,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -1286,7 +1291,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1332,6 +1337,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
+ // Reaching this point means that we successfully completed the 3-way
+ // handshake with our peer.
+ //
+ // Completing the 3-way handshake is an indication that the route is valid
+ // and the remote is reachable as the only way we can complete a handshake
+ // is if our SYN reached the remote and their ACK reached us.
+ e.route.ConfirmReachable()
+
drained := e.drainDone != nil
if drained {
close(e.drainDone)
@@ -1344,19 +1357,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// wakes up.
funcs := []struct {
w *sleep.Waker
- f func() *tcpip.Error
+ f func() tcpip.Error
}{
{
w: &e.sndWaker,
- f: e.handleWrite,
+ f: func() tcpip.Error {
+ e.handleWrite()
+ return nil
+ },
},
{
w: &e.sndCloseWaker,
- f: e.handleClose,
+ f: func() tcpip.Error {
+ e.handleClose()
+ return nil
+ },
},
{
w: &closeWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
// This means the socket is being closed due
// to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
@@ -1367,10 +1386,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.snd.resendWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
if !e.snd.retransmitTimerExpired() {
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
return nil
},
@@ -1381,7 +1400,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.newSegmentWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
return e.handleSegments(false /* fastPath */)
},
},
@@ -1391,7 +1410,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.notificationWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
n := e.fetchNotifications()
if n&notifyNonZeroReceiveWindow != 0 {
e.rcv.nonZeroWindow()
@@ -1408,11 +1427,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyReset != 0 || n&notifyAbort != 0 {
- return tcpip.ErrConnectionAborted
+ return &tcpip.ErrConnectionAborted{}
}
if n&notifyResetByPeer != 0 {
- return tcpip.ErrConnectionReset
+ return &tcpip.ErrConnectionReset{}
}
if n&notifyClose != 0 && closeTimer == nil {
@@ -1491,7 +1510,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- cleanupOnError := func(err *tcpip.Error) {
+ cleanupOnError := func(err tcpip.Error) {
e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 1d1b01a6c..2d90246e4 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -15,11 +15,11 @@
package tcp_test
import (
+ "strings"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
// Start connection attempt, it must fail.
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrNoRoute {
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
}
@@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) {
t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
}
+ var r strings.Reader
data := "Don't panic"
- nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ r.Reset(data)
+ nep.Write(&r, tcpip.WriteOptions{})
b = c.GetPacket()
tcp = header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
@@ -523,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) {
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
_, _, err := c.EP.Accept(&addr)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -547,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
@@ -611,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -633,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ea509ac73..6e4e26c39 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -386,12 +386,12 @@ type endpoint struct {
// hardError is meaningful only when state is stateError. It stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state. hardError is protected by endpoint mu.
- hardError *tcpip.Error `state:".(string)"`
+ hardError tcpip.Error
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// rcvReadMu synchronizes calls to Read.
//
@@ -557,7 +557,6 @@ type endpoint struct {
// When the send side is closed, the protocol goroutine is notified via
// sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex `state:"nosave"`
- sndBufSize int
sndBufUsed int
sndClosed bool
sndBufInQueue seqnum.Size
@@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
- sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
keepalive: keepalive{
// Linux defaults.
@@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetQuickAck(true)
+ e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- e.sndBufSize = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs tcpip.TCPReceiveBufferSizeRangeOption
@@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
- if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ sndBufSize := e.getSendBufferSize()
+ if e.sndClosed || e.sndBufUsed < sndBufSize {
result |= waiter.EventOut
}
e.sndBufMu.Unlock()
@@ -1059,7 +1059,7 @@ func (e *endpoint) Close() {
if isResetState {
// Close the endpoint without doing full shutdown and
// send a RST.
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
e.closeNoShutdownLocked()
// Wake up worker to close the endpoint.
@@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1293,14 +1293,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Preconditions: e.mu must be held to call this function.
-func (e *endpoint) hardErrorLocked() *tcpip.Error {
+func (e *endpoint) hardErrorLocked() tcpip.Error {
err := e.hardError
e.hardError = nil
return err
}
// Preconditions: e.mu must be held to call this function.
-func (e *endpoint) lastErrorLocked() *tcpip.Error {
+func (e *endpoint) lastErrorLocked() tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1309,7 +1309,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error {
}
// LastError implements tcpip.Endpoint.LastError.
-func (e *endpoint) LastError() *tcpip.Error {
+func (e *endpoint) LastError() tcpip.Error {
e.LockUser()
defer e.UnlockUser()
if err := e.hardErrorLocked(); err != nil {
@@ -1319,7 +1319,7 @@ func (e *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.LockUser()
e.lastErrorMu.Lock()
e.lastError = err
@@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvReadMu.Lock()
defer e.rcvReadMu.Unlock()
@@ -1337,7 +1337,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// can remove segments from the list through commitRead().
first, last, serr := e.startRead()
if serr != nil {
- if serr == tcpip.ErrClosedForReceive {
+ if _, ok := serr.(*tcpip.ErrClosedForReceive); ok {
e.stats.ReadErrors.ReadClosed.Increment()
}
return tcpip.ReadResult{}, serr
@@ -1377,7 +1377,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// If something is read, we must report it. Report error when nothing is read.
if done == 0 && err != nil {
- return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{}
}
return tcpip.ReadResult{
Count: done,
@@ -1389,7 +1389,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// inclusive range of segments that can be read.
//
// Precondition: e.rcvReadMu must be held.
-func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
+func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1398,7 +1398,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1414,17 +1414,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
if err := e.hardErrorLocked(); err != nil {
return nil, nil, err
}
- return nil, nil, tcpip.ErrClosedForReceive
+ return nil, nil, &tcpip.ErrClosedForReceive{}
}
e.stats.ReadErrors.NotConnected.Increment()
- return nil, nil, tcpip.ErrNotConnected
+ return nil, nil, &tcpip.ErrNotConnected{}
}
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return nil, nil, tcpip.ErrClosedForReceive
+ return nil, nil, &tcpip.ErrClosedForReceive{}
}
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
return e.rcvList.Front(), e.rcvList.Back(), nil
@@ -1476,106 +1476,117 @@ func (e *endpoint) commitRead(done int) *segment {
// moment. If the endpoint is not writable then it returns an error
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
-func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
if err := e.hardErrorLocked(); err != nil {
return 0, err
}
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
case !s.connecting() && !s.connected():
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
case s.connecting():
// As per RFC793, page 56, a send request arriving when in connecting
// state, can be queued to be completed after the state becomes
// connected. Return an error code for the caller of endpoint Write to
// try again, until the connection handshake is complete.
- return 0, tcpip.ErrWouldBlock
+ return 0, &tcpip.ErrWouldBlock{}
}
// Check if the connection has already been closed for sends.
if e.sndClosed {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
- avail := e.sndBufSize - e.sndBufUsed
+ sndBufSize := e.getSendBufferSize()
+ avail := sndBufSize - e.sndBufUsed
if avail <= 0 {
- return 0, tcpip.ErrWouldBlock
+ return 0, &tcpip.ErrWouldBlock{}
}
return avail, nil
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
e.LockUser()
- e.sndBufMu.Lock()
-
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, err
- }
-
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- }
-
- // Fetch data.
- v, perr := p.Payload(avail)
- if perr != nil || len(v) == 0 {
- // Note that perr may be nil if len(v) == 0.
- if opts.Atomic {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- }
- return 0, perr
- }
+ defer e.UnlockUser()
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- e.LockUser()
+ nextSeg, n, err := func() (*segment, int, tcpip.Error) {
e.sndBufMu.Lock()
+ defer e.sndBufMu.Unlock()
+
avail, err := e.isEndpointWritableLocked()
if err != nil {
- e.sndBufMu.Unlock()
- e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
- return 0, err
+ return nil, 0, err
}
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ v, err := func() ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ defer e.sndBufMu.Lock()
+
+ e.UnlockUser()
+ defer e.LockUser()
+ }
+
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ if _, err := io.ReadFull(p, v); err != nil {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v, nil
+ }()
+ if len(v) == 0 || err != nil {
+ return nil, 0, err
+ }
+
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
- }
- // Add data to the send queue.
- s := newOutgoingSegment(e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
- // Do the work inline.
- e.handleWrite()
- e.UnlockUser()
- return int64(len(v)), nil
+ return e.drainSendQueueLocked(), len(v), nil
+ }()
+ if err != nil {
+ return 0, err
+ }
+ e.sendData(nextSeg)
+ return int64(n), nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1682,8 +1693,16 @@ func (e *endpoint) OnCorkOptionSet(v bool) {
}
}
+func (e *endpoint) getSendBufferSize() int {
+ sndBufSize, err := e.ops.GetSendBufferSize()
+ if err != nil {
+ panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err))
+ }
+ return int(sndBufSize)
+}
+
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
const inetECNMask = 3
@@ -1711,7 +1730,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.MaxSegOption:
userMSS := v
if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
e.LockUser()
e.userMSS = uint16(userMSS)
@@ -1722,7 +1741,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Return not supported if attempting to set this option to
// anything other than path MTU discovery disabled.
if v != tcpip.PMTUDiscoveryDont {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
case tcpip.ReceiveBufferSizeOption:
@@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.rcvListMu.Unlock()
e.UnlockUser()
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss tcpip.TCPSendBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
- panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
- }
-
- if v > ss.Max {
- v = ss.Max
- }
-
- if v < math.MaxInt32/SegOverheadFactor {
- v *= SegOverheadFactor
- if v < ss.Min {
- v = ss.Min
- }
- } else {
- v = math.MaxInt32
- }
-
- e.sndBufMu.Lock()
- e.sndBufSize = v
- e.sndBufMu.Unlock()
-
case tcpip.TTLOption:
e.LockUser()
e.ttl = uint8(v)
@@ -1807,7 +1801,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.TCPSynCountOption:
if v < 1 || v > 255 {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
e.LockUser()
e.maxSynRetries = uint8(v)
@@ -1823,7 +1817,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
e.UnlockUser()
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
}
var rs tcpip.TCPReceiveBufferSizeRangeOption
@@ -1844,7 +1838,7 @@ func (e *endpoint) HasNIC(id int32) bool {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
@@ -1890,7 +1884,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// Linux returns ENOENT when an invalid congestion
// control algorithm is specified.
- return tcpip.ErrNoSuchFile
+ return &tcpip.ErrNoSuchFile{}
case *tcpip.TCPLingerTimeoutOption:
e.LockUser()
@@ -1933,13 +1927,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
// readyReceiveSize returns the number of bytes ready to be received.
-func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+func (e *endpoint) readyReceiveSize() (int, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
// The endpoint cannot be in listen state.
if e.EndpointState() == StateListen {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
e.rcvListMu.Lock()
@@ -1949,7 +1943,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.KeepaliveCountOption:
e.keepalive.Lock()
@@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
- case tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- v := e.sndBufSize
- e.sndBufMu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvListMu.Lock()
v := e.rcvBufSize
@@ -2019,24 +2007,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return 1, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
+ info := tcpip.TCPInfoOption{}
+ e.LockUser()
+ snd := e.snd
+ if snd != nil {
+ // We do not calculate RTT before sending the data packets. If
+ // the connection did not send and receive data, then RTT will
+ // be zero.
+ snd.rtt.Lock()
+ info.RTT = snd.rtt.srtt
+ info.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+
+ info.RTO = snd.rto
+ info.CcState = snd.state
+ info.SndSsthresh = uint32(snd.sndSsthresh)
+ info.SndCwnd = uint32(snd.sndCwnd)
+ info.ReorderSeen = snd.rc.reorderSeen
}
+ e.UnlockUser()
+ return info
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.TCPInfoOption:
- *o = tcpip.TCPInfoOption{}
- e.LockUser()
- snd := e.snd
- e.UnlockUser()
- if snd != nil {
- snd.rtt.Lock()
- o.RTT = snd.rtt.srtt
- o.RTTVar = snd.rtt.rttvar
- snd.rtt.Unlock()
- }
+ *o = e.getTCPInfo()
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
@@ -2082,14 +2084,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
return nil
}
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -2098,18 +2100,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect connects the endpoint to its peer.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
err := e.connect(addr, true, true)
- if err != nil && !err.IgnoreStats() {
- // Connect failed. Let's wake up any waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
+ if err != nil {
+ if !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
}
return err
}
@@ -2120,7 +2124,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// created (so no new handshaking is done); for stack-accepted connections not
// yet accepted by the app, they are restored without running the main goroutine
// here.
-func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
@@ -2139,7 +2143,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return nil
}
// Otherwise return that it's already connected.
- return tcpip.ErrAlreadyConnected
+ return &tcpip.ErrAlreadyConnected{}
}
nicID := addr.NIC
@@ -2152,7 +2156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if nicID != 0 && nicID != e.boundNICID {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
nicID = e.boundNICID
@@ -2164,16 +2168,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
case StateConnecting, StateSynSent, StateSynRecv:
// A connection request has already been issued but hasn't completed
// yet.
- return tcpip.ErrAlreadyConnecting
+ return &tcpip.ErrAlreadyConnecting{}
case StateError:
if err := e.hardErrorLocked(); err != nil {
return err
}
- return tcpip.ErrConnectionAborted
+ return &tcpip.ErrConnectionAborted{}
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// Find a route to the desired destination.
@@ -2190,7 +2194,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2229,12 +2233,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
- if err != tcpip.ErrPortInUse || !reuse {
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
transEPID := e.ID
@@ -2278,9 +2282,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
- if err == tcpip.ErrPortInUse {
+ if _, ok := err.(*tcpip.ErrPortInUse); ok {
return false, nil
}
return false, err
@@ -2335,23 +2339,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
- return tcpip.ErrConnectStarted
+ return &tcpip.ErrConnectStarted{}
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
return e.shutdownLocked(flags)
}
-func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.shutdownFlags |= flags
switch {
case e.EndpointState().connected():
@@ -2366,7 +2370,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
// Wake up worker to terminate loop.
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
@@ -2380,7 +2384,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// Already closed.
e.sndBufMu.Unlock()
if e.EndpointState() == StateTimeWait {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
return nil
}
@@ -2413,22 +2417,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
return nil
default:
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (e *endpoint) Listen(backlog int) tcpip.Error {
err := e.listen(backlog)
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
+ if err != nil {
+ if !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
}
return err
}
-func (e *endpoint) listen(backlog int) *tcpip.Error {
+func (e *endpoint) listen(backlog int) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
@@ -2446,7 +2452,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// Adjust the size of the channel iff we can fix
// existing pending connections into the new one.
if len(e.acceptedChan) > backlog {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
if cap(e.acceptedChan) == backlog {
return nil
@@ -2478,11 +2484,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// Endpoint must be bound before it can transition to listen mode.
if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
@@ -2518,7 +2524,7 @@ func (e *endpoint) startAcceptedLoop() {
// to an endpoint previously set to listen mode.
//
// addr if not-nil will contain the peer address of the returned endpoint.
-func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2527,7 +2533,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
if rcvClosed || e.EndpointState() != StateListen {
- return nil, nil, tcpip.ErrInvalidEndpointState
+ return nil, nil, &tcpip.ErrInvalidEndpointState{}
}
// Get the new accepted endpoint.
@@ -2538,7 +2544,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
case n = <-e.acceptedChan:
e.acceptCond.Signal()
default:
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
if peerAddr != nil {
*peerAddr = n.getRemoteAddress()
@@ -2547,19 +2553,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
}
// Bind binds the endpoint to a specific local port and optionally address.
-func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
return e.bindLocked(addr)
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
if e.EndpointState() != StateInitial {
- return tcpip.ErrAlreadyBound
+ return &tcpip.ErrAlreadyBound{}
}
e.BindAddr = addr.Addr
@@ -2587,7 +2593,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
if len(addr.Addr) != 0 {
nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
e.ID.LocalAddress = addr.Addr
}
@@ -2604,7 +2610,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2628,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2640,12 +2646,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
if !e.EndpointState().connected() {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return e.getRemoteAddress(), nil
@@ -2677,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
-func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
// Update last error first.
e.lastErrorMu.Lock()
e.lastError = err
@@ -2726,26 +2732,27 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlAddressUnreachable:
- e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
// updateSndBufferUsage is called by the protocol goroutine when room opens up
// in the send buffer. The number of newly available bytes is v.
func (e *endpoint) updateSndBufferUsage(v int) {
+ sendBufferSize := e.getSendBufferSize()
e.sndBufMu.Lock()
- notify := e.sndBufUsed >= e.sndBufSize>>1
+ notify := e.sndBufUsed >= sendBufferSize>>1
e.sndBufUsed -= v
- // We only notify when there is half the sndBufSize available after
+ // We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ notify = notify && e.sndBufUsed < int(sendBufferSize)>>1
e.sndBufMu.Unlock()
if notify {
@@ -2957,8 +2964,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
// Copy endpoint send state.
+ sndBufSize := e.getSendBufferSize()
e.sndBufMu.Lock()
- s.SndBufSize = e.sndBufSize
+ s.SndBufSize = sndBufSize
s.SndBufUsed = e.sndBufUsed
s.SndClosed = e.sndClosed
s.SndBufInQueue = e.sndBufInQueue
@@ -3023,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
rc := &e.snd.rc
s.Sender.RACKState = stack.TCPRACKState{
- XmitTime: rc.xmitTime,
- EndSequence: rc.endSequence,
- FACK: rc.fack,
- RTT: rc.rtt,
- Reord: rc.reorderSeen,
- DSACKSeen: rc.dsackSeen,
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ Reord: rc.reorderSeen,
+ DSACKSeen: rc.dsackSeen,
+ ReoWnd: rc.reoWnd,
+ ReoWndIncr: rc.reoWndIncr,
+ ReoWndPersist: rc.reoWndPersist,
+ RTTSeq: rc.rttSeq,
}
return s
}
@@ -3103,3 +3115,17 @@ func (e *endpoint) Wait() {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// GetTCPSendBufferLimits is used to get send buffer size limits for TCP.
+func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption {
+ var ss tcpip.TCPSendBufferSizeRangeOption
+ if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err))
+ }
+
+ return tcpip.SendBufferSizeOption{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index ba67176b5..c21dbc682 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -55,9 +55,11 @@ func (e *endpoint) beforeSave() {
case epState.connected() || epState.handshake():
if !e.route.HasSaveRestoreCapability() {
if !e.route.HasDisconncetOkCapability() {
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
+ panic(&tcpip.ErrSaveRejection{
+ Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort),
+ })
}
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
e.mu.Unlock()
e.Close()
e.mu.Lock()
@@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss tcpip.TCPSendBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ sendBufferSize := e.getSendBufferSize()
+ if sendBufferSize < ss.Min || sendBufferSize > ss.Max {
+ panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max))
}
}
@@ -228,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
e.mu.Lock()
@@ -265,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
+ err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
connectingLoading.Done()
@@ -292,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
-// saveLastError is invoked by stateify.
-func (e *endpoint) saveLastError() string {
- if e.lastError == nil {
- return ""
- }
-
- return e.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (e *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- e.lastError = tcpip.StringToError(s)
-}
-
// saveRecentTSTime is invoked by stateify.
func (e *endpoint) saveRecentTSTime() unixTime {
return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
@@ -320,24 +308,6 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
e.recentTSTime = time.Unix(unix.second, unix.nano)
}
-// saveHardError is invoked by stateify.
-func (e *endpoint) saveHardError() string {
- if e.hardError == nil {
- return ""
- }
-
- return e.hardError.String()
-}
-
-// loadHardError is invoked by stateify.
-func (e *endpoint) loadHardError(s string) {
- if s == "" {
- return
- }
-
- e.hardError = tcpip.StringToError(s)
-}
-
// saveMeasureTime is invoked by stateify.
func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 596178625..2f9fe7ee0 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// CreateEndpoint creates a TCP endpoint for the connection request, performing
// the 3-way handshake in the process.
-func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.segment == nil {
- return nil, tcpip.ErrInvalidEndpointState
+ return nil, &tcpip.ErrInvalidEndpointState{}
}
f := r.forwarder
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 1720370c9..04012cd40 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
@@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int {
// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
@@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
-func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPSACKEnabled:
p.mu.Lock()
@@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPSendBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.sendBufferSize = *v
@@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPReceiveBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.recvBufferSize = *v
@@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
}
// linux returns ENOENT when an invalid congestion control
// is specified.
- return tcpip.ErrNoSuchFile
+ return &tcpip.ErrNoSuchFile{}
case *tcpip.TCPModerateReceiveBufferOption:
p.mu.Lock()
@@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPTimeWaitReuseOption:
if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.timeWaitReuse = *v
@@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPSynRetriesOption:
if *v < 1 || *v > 255 {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.synRetries = uint8(*v)
@@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPSACKEnabled:
p.mu.RLock()
@@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 307bacca5..d85cb405a 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -22,12 +22,21 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
-// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as
-// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.
-// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is
-// 1, PTO is inflated by WCDelAckT time to compensate for a potential long
-// delayed ACK timer at the receiver.
-const wcDelayedACKTimeout = 200 * time.Millisecond
+const (
+ // wcDelayedACKTimeout is the recommended maximum delayed ACK timer
+ // value as defined in the RFC. It stands for worst case delayed ACK
+ // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by
+ // WCDelAckT time to compensate for a potential long delayed ACK timer
+ // at the receiver.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.
+ wcDelayedACKTimeout = 200 * time.Millisecond
+
+ // tcpRACKRecoveryThreshold is the number of loss recoveries for which
+ // the reorder window is inflated and after that the reorder window is
+ // reset to its initial value of minRTT/4.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2.
+ tcpRACKRecoveryThreshold = 16
+)
// RACK is a loss detection algorithm used in TCP to detect packet loss and
// reordering using transmission timestamp of the packets instead of packet or
@@ -44,6 +53,11 @@ type rackControl struct {
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
+ // exitedRecovery indicates if the connection is exiting loss recovery.
+ // This flag is set if the sender is leaving the recovery after
+ // receiving an ACK and is reset during updating of reorder window.
+ exitedRecovery bool
+
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
@@ -51,15 +65,30 @@ type rackControl struct {
// minRTT is the estimated minimum RTT of the connection.
minRTT time.Duration
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // reoWnd is the reordering window time used for recording packet
+ // transmission times. It is used to defer the moment at which RACK
+ // marks a packet lost.
+ reoWnd time.Duration
+
+ // reoWndIncr is the multiplier applied to adjust reorder window.
+ reoWndIncr uint8
+
+ // reoWndPersist is the number of loss recoveries before resetting
+ // reorder window.
+ reoWndPersist int8
+
// rtt is the RTT of the most recently delivered packet on the
// connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
rtt time.Duration
- // reorderSeen indicates if reordering has been detected on this
- // connection.
- reorderSeen bool
+ // rttSeq is the SND.NXT when rtt is updated.
+ rttSeq seqnum.Value
// xmitTime is the latest transmission timestamp of rackControl.seg.
xmitTime time.Time `state:".(unixTime)"`
@@ -75,29 +104,36 @@ type rackControl struct {
// tlpHighRxt the value of sender.sndNxt at the time of sending
// a TLP retransmission.
tlpHighRxt seqnum.Value
+
+ // snd is a reference to the sender.
+ snd *sender
}
// init initializes RACK specific fields.
-func (rc *rackControl) init() {
+func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
+ rc.fack = iss
+ rc.reoWndIncr = 1
+ rc.snd = snd
rc.probeTimer.init(&rc.probeWaker)
}
// update will update the RACK related fields when an ACK has been received.
-// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
-func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
+func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := time.Now().Sub(seg.xmitTime)
+ tsOffset := rc.snd.ep.tsOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
- // 1. When Timestamping option is available, if the TSVal is less than the
- // transmit time of the most recent retransmitted packet.
+ // 1. When Timestamping option is available, if the TSVal is less than
+ // the transmit time of the most recent retransmitted packet.
// 2. When RTT calculated for the packet is less than the smoothed RTT
// for the connection.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
return
}
}
@@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) {
}
}
-// setDSACKSeen updates rack control if duplicate SACK is seen by the connection.
-func (rc *rackControl) setDSACKSeen() {
- rc.dsackSeen = true
+func (rc *rackControl) setDSACKSeen(dsackSeen bool) {
+ rc.dsackSeen = dsackSeen
}
// shouldSchedulePTO dictates whether we should schedule a PTO or not.
@@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool {
// The connection supports SACK.
s.ep.sackPermitted &&
// The connection is not in loss recovery.
- (s.state != RTORecovery && s.state != SACKRecovery) &&
+ (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) &&
// The connection has no SACKed sequences in the SACK scoreboard.
s.ep.scoreboard.Sacked() == 0
}
@@ -193,7 +228,7 @@ func (s *sender) schedulePTO() {
// probeTimerExpired is the same as TLP_send_probe() as defined in
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2.
-func (s *sender) probeTimerExpired() *tcpip.Error {
+func (s *sender) probeTimerExpired() tcpip.Error {
if !s.rc.probeTimer.checkExpiration() {
return nil
}
@@ -272,3 +307,82 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
}
}
}
+
+// updateRACKReorderWindow updates the reorder window.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 4: Update RACK reordering window
+// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as
+// an allowance for settling time before marking a packet lost. RACK starts
+// initially with a conservative window of min_RTT/4. If no reordering has
+// been observed RACK uses reo_wnd of zero during loss recovery, in order to
+// retransmit quickly, or when the number of DUPACKs exceeds the classic
+// DUPACKthreshold.
+func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
+ dsackSeen := rc.dsackSeen
+ snd := rc.snd
+
+ // React to DSACK once per round trip.
+ // If SND.UNA < RACK.rtt_seq:
+ // RACK.dsack = false
+ if snd.sndUna.LessThan(rc.rttSeq) {
+ dsackSeen = false
+ }
+
+ // If RACK.dsack:
+ // RACK.reo_wnd_incr += 1
+ // RACK.dsack = false
+ // RACK.rtt_seq = SND.NXT
+ // RACK.reo_wnd_persist = 16
+ if dsackSeen {
+ rc.reoWndIncr++
+ dsackSeen = false
+ rc.rttSeq = snd.sndNxt
+ rc.reoWndPersist = tcpRACKRecoveryThreshold
+ } else if rc.exitedRecovery {
+ // Else if exiting loss recovery:
+ // RACK.reo_wnd_persist -= 1
+ // If RACK.reo_wnd_persist <= 0:
+ // RACK.reo_wnd_incr = 1
+ rc.reoWndPersist--
+ if rc.reoWndPersist <= 0 {
+ rc.reoWndIncr = 1
+ }
+ rc.exitedRecovery = false
+ }
+
+ // Reorder window is zero during loss recovery, or when the number of
+ // DUPACKs exceeds the classic DUPACKthreshold.
+ // If RACK.reord is FALSE:
+ // If in loss recovery: (If in fast or timeout recovery)
+ // RACK.reo_wnd = 0
+ // Return
+ // Else if RACK.pkts_sacked >= RACK.dupthresh:
+ // RACK.reo_wnd = 0
+ // return
+ if !rc.reorderSeen {
+ if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery {
+ rc.reoWnd = 0
+ return
+ }
+
+ if snd.sackedOut >= nDupAckThreshold {
+ rc.reoWnd = 0
+ return
+ }
+ }
+
+ // Calculate reorder window.
+ // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr
+ // RACK.reo_wnd = min(RACK.reo_wnd, SRTT)
+ snd.rtt.Lock()
+ srtt := snd.rtt.srtt
+ snd.rtt.Unlock()
+ rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr))
+ if srtt < rc.reoWnd {
+ rc.reoWnd = srtt
+ }
+}
+
+func (rc *rackControl) exitRecovery() {
+ rc.exitedRecovery = true
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 405a6dce7..7a7c402c4 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -347,7 +347,7 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) {
r.ep.rcvListMu.Lock()
rcvClosed := r.ep.rcvClosed || r.closed
r.ep.rcvListMu.Unlock()
@@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- return true, tcpip.ErrConnectionAborted
+ return true, &tcpip.ErrConnectionAborted{}
}
if state == StateFinWait1 {
break
@@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- return true, tcpip.ErrConnectionAborted
+ return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
state := r.ep.EndpointState()
closed := r.ep.closed
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 079d90848..dfc8fd248 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -48,28 +48,6 @@ const (
MaxRetries = 15
)
-// ccState indicates the current congestion control state for this sender.
-type ccState int
-
-const (
- // Open indicates that the sender is receiving acks in order and
- // no loss or dupACK's etc have been detected.
- Open ccState = iota
- // RTORecovery indicates that an RTO has occurred and the sender
- // has entered an RTO based recovery phase.
- RTORecovery
- // FastRecovery indicates that the sender has entered FastRecovery
- // based on receiving nDupAck's. This state is entered only when
- // SACK is not in use.
- FastRecovery
- // SACKRecovery indicates that the sender has entered SACK based
- // recovery.
- SACKRecovery
- // Disorder indicates the sender either received some SACK blocks
- // or dupACK's.
- Disorder
-)
-
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
type congestionControl interface {
@@ -204,7 +182,7 @@ type sender struct {
maxSentAck seqnum.Value
// state is the current state of congestion control for this endpoint.
- state ccState
+ state tcpip.CongestionControlState
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
@@ -280,14 +258,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
- rc: rackControl{
- fack: iss,
- },
gso: ep.gso != nil,
}
- s.rc.init()
-
if s.gso {
s.ep.gso.MSS = uint16(maxPayloadSize)
}
@@ -295,6 +268,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.cc = s.initCongestionControl(ep.cc)
s.lr = s.initLossRecovery()
+ s.rc.init(s, iss)
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
@@ -593,7 +567,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
- s.state = RTORecovery
+ s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
@@ -1018,7 +992,7 @@ func (s *sender) sendData() {
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
- if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
+ if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
if s.sndCwnd > InitialCwnd {
s.sndCwnd = InitialCwnd
}
@@ -1062,14 +1036,14 @@ func (s *sender) enterRecovery() {
s.fr.highRxt = s.sndUna
s.fr.rescueRxt = s.sndUna
if s.ep.sackPermitted {
- s.state = SACKRecovery
+ s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
// Set TLPRxtOut to false according to
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1.
s.rc.tlpRxtOut = false
return
}
- s.state = FastRecovery
+ s.state = tcpip.FastRecovery
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
@@ -1080,7 +1054,6 @@ func (s *sender) leaveRecovery() {
// Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
-
s.cc.PostRecovery()
}
@@ -1166,7 +1139,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
s.fr.highRxt = s.sndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
- s.state = Disorder
+ s.state = tcpip.Disorder
return false
}
@@ -1217,11 +1190,13 @@ func (s *sender) isDupAck(seg *segment) bool {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
+ s.rc.setDSACKSeen(false)
+
// Look for DSACK block.
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
- s.rc.setDSACKSeen()
+ s.rc.setDSACKSeen(true)
idx = 1
n--
}
@@ -1242,7 +1217,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
for _, sb := range sackBlocks {
for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
- s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
seg.acked = true
s.sackedOut += s.pCount(seg, s.maxPayloadSize)
@@ -1412,6 +1387,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
acked := s.sndUna.Size(ack)
s.sndUna = ack
+ // The remote ACK-ing at least 1 byte is an indication that we have a
+ // full-duplex connection to the remote as the only way we will receive an
+ // ACK is if the remote received data that we previously sent.
+ //
+ // As of writing, linux seems to only confirm a route as reachable when
+ // forward progress is made which is indicated by an ACK that removes data
+ // from the retransmit queue.
+ if acked > 0 {
+ s.ep.route.ConfirmReachable()
+ }
+
ackLeft := acked
originalOutstanding := s.outstanding
for ackLeft > 0 {
@@ -1435,7 +1421,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Update the RACK fields if SACK is enabled.
if s.ep.sackPermitted && !seg.acked {
- s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
}
@@ -1464,7 +1450,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
if !s.fr.active {
s.cc.Update(originalOutstanding - s.outstanding)
if s.fr.last.LessThan(s.sndUna) {
- s.state = Open
+ s.state = tcpip.Open
+ // Update RACK when we are exiting fast or RTO
+ // recovery as described in the RFC
+ // draft-ietf-tcpm-rack-08 Section-7.2 Step 4.
+ s.rc.exitRecovery()
}
}
@@ -1488,6 +1478,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update RACK reorder window.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // * Upon receiving an ACK:
+ // * Step 4: Update RACK reordering window
+ s.rc.updateRACKReorderWindow(rcvdSeg)
+
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
if s.fr.active {
@@ -1508,7 +1504,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// sendSegment sends the specified segment.
-func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+func (s *sender) sendSegment(seg *segment) tcpip.Error {
if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
@@ -1539,7 +1535,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
-func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index f7aaee23f..ced3a9c58 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -21,13 +21,13 @@
package tcp_test
import (
+ "bytes"
"fmt"
"math"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
@@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
-
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
- half := data[:len(data)/2]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data[:len(data)/2])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
- half = data[len(data)/2:]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data[len(data)/2:])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 342eb5eb8..a6a26b705 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -15,11 +15,12 @@
package tcp_test
import (
+ "bytes"
+ "fmt"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -61,14 +62,16 @@ func TestRACKUpdate(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(maxPayload)
+ data := make([]byte, maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
xmitTime = time.Now()
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -114,13 +117,15 @@ func TestRACKDetectReorder(t *testing.T) {
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(ackNumToVerify * maxPayload)
+ data := make([]byte, ackNumToVerify*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -141,17 +146,19 @@ func TestRACKDetectReorder(t *testing.T) {
<-probeDone
}
-func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View {
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(numPackets * maxPayload)
+ data := make([]byte, numPackets*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -528,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) {
// ACK before the test completes.
<-probeDone
}
+
+func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) {
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK detects DSACK.
+ n++
+ if n < numACK {
+ return
+ }
+
+ if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT {
+ probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT)
+ return
+ }
+
+ if state.Sender.RACKState.ReoWndIncr != 1 {
+ probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr)
+ return
+ }
+
+ if state.Sender.RACKState.ReoWndPersist > 0 {
+ probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist)
+ return
+ }
+ probeDone <- nil
+ })
+}
+
+func TestRACKCheckReorderWindow(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan error)
+ const ackNumToVerify = 3
+ addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone)
+
+ const numPackets = 7
+ sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-6] packets and SACK #7 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed packets [2-6] which indicates there is reordering
+ // in the connection.
+ bytesRead += 6 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ if err := <-probeDone; err != nil {
+ t.Fatalf("unexpected values for RACK variables: %v", err)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6635bb815..5024bc925 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -15,6 +15,7 @@
package tcp_test
import (
+ "bytes"
"fmt"
"log"
"reflect"
@@ -22,7 +23,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) {
createConnectedWithSACKAndTS(c)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 93683b921..da2730e27 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io/ioutil"
"math"
+ "strings"
"testing"
"time"
@@ -26,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -48,7 +48,7 @@ type endpointTester struct {
}
// CheckReadError issues a read to the endpoint and checking for an error.
-func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) {
t.Helper()
res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
if got != want {
@@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
}
for w.N != 0 {
_, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for receive to be notified.
select {
case <-notifyRead:
@@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) {
wq.EventRegister(&waitEntry, waiter.EventHUp)
defer wq.EventUnregister(&waitEntry)
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
// Close the connection, wait for completion.
@@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) {
// Call Connect again to retreive the handshake failure status
// and stats updates.
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrAborted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{})
+ }
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
@@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
- if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{})
+ }
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
@@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
- if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ {
+ err := c.EP.Connect(connectAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
}
// Receive SYN packet with our user supplied MSS.
@@ -1347,10 +1359,9 @@ func TestTOSV4(t *testing.T) {
testV4Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1396,10 +1407,9 @@ func TestTrafficClassV6(t *testing.T) {
testV6Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1444,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("unexpected return value from Connect: %s", err)
}
@@ -1504,8 +1515,9 @@ func TestSynSent(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ err := c.EP.Connect(addr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{})
}
// Receive SYN packet.
@@ -1550,9 +1562,9 @@ func TestSynSent(t *testing.T) {
ept := endpointTester{c.EP}
if test.reset {
- ept.CheckReadError(t, tcpip.ErrConnectionRefused)
+ ept.CheckReadError(t, &tcpip.ErrConnectionRefused{})
} else {
- ept.CheckReadError(t, tcpip.ErrAborted)
+ ept.CheckReadError(t, &tcpip.ErrAborted{})
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1578,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1603,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1642,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, rcvBufSz)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1718,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1786,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1868,13 +1880,13 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1893,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -2054,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
@@ -2176,10 +2188,9 @@ func TestSimpleSend(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2217,10 +2228,9 @@ func TestZeroWindowSend(t *testing.T) {
c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2285,10 +2295,9 @@ func TestScaledWindowConnect(t *testing.T) {
})
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2317,10 +2326,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2376,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -2391,10 +2399,9 @@ func TestScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2450,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -2465,10 +2472,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2632,9 +2638,10 @@ func TestSegmentMerging(t *testing.T) {
// Send tcp.InitialCwnd number of segments to fill up
// InitialWindow but don't ACK. That should prevent
// anymore packets from going out.
+ var r bytes.Reader
for i := 0; i < tcp.InitialCwnd; i++ {
- view := buffer.NewViewFromBytes([]byte{0})
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset([]byte{0})
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2644,8 +2651,8 @@ func TestSegmentMerging(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2714,8 +2721,9 @@ func TestDelay(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2761,8 +2769,9 @@ func TestUndelay(t *testing.T) {
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2845,8 +2854,9 @@ func TestMSSNotDelayed(t *testing.T) {
allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2894,10 +2904,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
data[i] = byte(i)
}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2963,7 +2972,7 @@ func TestSetTTL(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -2973,8 +2982,11 @@ func TestSetTTL(t *testing.T) {
t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
}
// Receive SYN packet.
@@ -3034,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -3090,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -3115,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
defer c.Cleanup()
s := c.Stack()
- ch := make(chan *tcpip.Error, 1)
+ ch := make(chan tcpip.Error, 1)
f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = r.CreateEndpoint(&c.WQ)
ch <- err
})
@@ -3146,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -3165,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventOut)
defer c.WQ.EventUnregister(&we)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
// Receive SYN packet.
@@ -3276,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err {
- case tcpip.ErrWouldBlock:
+ switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case *tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
+ t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
- case tcpip.ErrConnectionReset:
+ case *tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
}
@@ -3328,9 +3344,11 @@ func TestSendOnResetConnection(t *testing.T) {
time.Sleep(1 * time.Second)
// Try to write.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ var r bytes.Reader
+ r.Reset(make([]byte, 10))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
+ t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
}
@@ -3352,7 +3370,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
defer c.WQ.EventUnregister(&waitEntry)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3409,7 +3429,9 @@ func TestMaxRTO(t *testing.T) {
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3458,7 +3480,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
}
- if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(make([]byte, tc.size))
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
pkt := c.GetPacket()
@@ -3595,8 +3619,10 @@ func TestFinWithNoPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3667,9 +3693,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
- view := buffer.NewView(10)
+ view := make([]byte, 10)
+ var r bytes.Reader
for i := tcp.InitialCwnd; i > 0; i-- {
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
}
@@ -3754,8 +3782,10 @@ func TestFinWithPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3781,7 +3811,8 @@ func TestFinWithPendingData(t *testing.T) {
})
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3841,8 +3872,10 @@ func TestFinWithPartialAck(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3879,7 +3912,8 @@ func TestFinWithPartialAck(t *testing.T) {
)
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3985,8 +4019,10 @@ func scaledSendWindow(t *testing.T, scale uint8) {
})
// Send some data. Check that it's capped by the window size.
- view := buffer.NewView(65535)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 65535)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4170,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4249,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
- ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
+ {
+ _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
+ if _, ok := err.(*tcpip.ErrClosedForReceive); !ok {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{})
+ }
}
}
@@ -4263,7 +4302,7 @@ func TestReusePort(t *testing.T) {
defer c.Cleanup()
// First case, just an endpoint that was bound.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
@@ -4293,8 +4332,11 @@ func TestReusePort(t *testing.T) {
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
c.EP.Close()
@@ -4351,9 +4393,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ s, err := ep.SocketOptions().GetSendBufferSize()
if err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
+ t.Fatalf("GetSendBufferSize failed: %s", err)
}
if int(s) != v {
@@ -4459,9 +4501,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
- t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
- }
+ ep.SocketOptions().SetSendBufferSize(149, true)
checkSendBufferSize(t, ep, 300)
@@ -4473,9 +4513,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Values above max are capped at max and then doubled.
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
- }
+ ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
// Values above max are capped at max and then doubled.
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
@@ -4505,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) {
testActions := []struct {
name string
setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
+ setBindToDeviceError tcpip.Error
getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
{"BindToExistent", nicIDPtr(321), nil, 321},
{"UnbindToDevice", nicIDPtr(0), nil, 0},
}
@@ -4529,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) {
}
}
-func makeStack() (*stack.Stack, *tcpip.Error) {
+func makeStack() (*stack.Stack, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -4599,8 +4637,11 @@ func TestSelfConnect(t *testing.T) {
wq.EventRegister(&waitEntry, waiter.EventOut)
defer wq.EventUnregister(&waitEntry)
- if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
<-notifyCh
@@ -4610,9 +4651,9 @@ func TestSelfConnect(t *testing.T) {
// Write something.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4752,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
- want := tcpip.ErrConnectStarted
+ var want tcpip.Error = &tcpip.ErrConnectStarted{}
if collides {
- want = tcpip.ErrNoPortAvailable
+ want = &tcpip.ErrNoPortAvailable{}
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
@@ -4785,12 +4826,13 @@ func TestPathMTUDiscovery(t *testing.T) {
// Send 3200 bytes of data.
const writeSize = 3200
- data := buffer.NewView(writeSize)
+ data := make([]byte, writeSize)
for i := range data {
data[i] = byte(i)
}
-
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4878,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) {
func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
cc tcpip.CongestionControlOption
- err *tcpip.Error
+ err tcpip.Error
}{
{"reno", nil},
{"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
+ {"blahblah", &tcpip.ErrNoSuchFile{}},
}
for _, tc := range testCases {
@@ -4964,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) {
func TestEndpointSetCongestionControl(t *testing.T) {
testCases := []struct {
cc tcpip.CongestionControlOption
- err *tcpip.Error
+ err tcpip.Error
}{
{"reno", nil},
{"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
+ {"blahblah", &tcpip.ErrNoSuchFile{}},
}
for _, connected := range []bool{false, true} {
@@ -4978,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5074,12 +5116,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5163,7 +5207,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -5270,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5313,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) {
for i := 0; i < listenBacklog; i++ {
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5330,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5342,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5358,7 +5402,9 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
@@ -5583,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5658,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
newEP, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5674,7 +5720,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
@@ -5692,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5733,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
defer c.WQ.EventUnregister(&we)
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5749,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5763,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5807,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5882,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
})
newEP, _, err := c.EP.Accept(nil)
-
- if err != nil && err != tcpip.ErrWouldBlock {
+ switch err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ default:
t.Fatalf("Accept failed: %s", err)
}
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -5908,7 +5957,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil {
+ var r strings.Reader
+ r.Reset(data)
+ if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5953,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
// Verify that there is only one acceptable connection at this point.
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6023,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
// Now check that there is one acceptable connections.
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6055,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
ept := endpointTester{ep}
- ept.CheckReadError(t, tcpip.ErrNotConnected)
+ ept.CheckReadError(t, &tcpip.ErrNotConnected{})
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6075,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
defer wq.EventUnregister(&we)
aep, _, err := ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6091,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
+ {
+ err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok {
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{})
+ }
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
@@ -6211,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// window increases to the full available buffer size.
for {
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
break
}
}
@@ -6335,7 +6389,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalCopied := 0
for {
res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
break
}
totalCopied += res.Count
@@ -6527,7 +6581,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6646,7 +6700,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6753,7 +6807,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6843,7 +6897,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Try to accept the connection.
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6917,7 +6971,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -7067,7 +7121,7 @@ func TestTCPCloseWithData(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -7103,10 +7157,10 @@ func TestTCPCloseWithData(t *testing.T) {
// Now write a few bytes and then close the endpoint.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7204,8 +7258,10 @@ func TestTCPUserTimeout(t *testing.T) {
}
// Send some data and wait before ACKing it.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7256,7 +7312,7 @@ func TestTCPUserTimeout(t *testing.T) {
)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7300,7 +7356,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Check that the connection is still alive.
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7339,7 +7395,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7494,8 +7550,9 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
+ _, _, err := c.EP.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
}
// Send data. This should result in an acceptable endpoint.
@@ -7552,8 +7609,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
+ _, _, err := c.EP.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7675,13 +7733,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
s := c.Stack()
testCases := []struct {
v int
- err *tcpip.Error
+ err tcpip.Error
}{
{int(tcpip.TCPTimeWaitReuseDisabled), nil},
{int(tcpip.TCPTimeWaitReuseGlobal), nil},
{int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
- {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}},
}
for _, tc := range testCases {
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index b65091c3c..5a9745ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -22,7 +22,6 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
@@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ee55f030c..b1cb9a324 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
// only endpoint instead of a default dual stack socket.
func (c *Context) CreateV6Endpoint(v6only bool) {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
// Create creates a TCP endpoint.
func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
@@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
err = c.EP.Connect(testFullAddr)
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
}
// Receive SYN packet.
@@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9f9b3d510..31a5ddce9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -97,9 +97,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
state EndpointState
@@ -111,8 +109,8 @@ type endpoint struct {
multicastNICID tcpip.NICID
portFlags ports.Flags
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
@@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// Linux defaults to TTL=1.
multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
e.ops.SetMulticastLoop(true)
+ e.ops.SetSendBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- e.sndBufSizeMax = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -217,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
-func (e *endpoint) LastError() *tcpip.Error {
+func (e *endpoint) LastError() tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
@@ -227,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -246,7 +244,7 @@ func (e *endpoint) Close() {
switch e.EndpointState() {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
@@ -284,7 +282,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
}
@@ -292,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
e.rcvMu.Lock()
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -342,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
@@ -353,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.EndpointState() {
case StateInitial:
case StateConnected:
@@ -361,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
case StateBound:
if to == nil {
- return false, tcpip.ErrDestinationRequired
+ return false, &tcpip.ErrDestinationRequired{}
}
return false, nil
default:
- return false, tcpip.ErrInvalidEndpointState
+ return false, &tcpip.ErrInvalidEndpointState{}
}
e.mu.RUnlock()
@@ -391,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -417,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -438,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
if err := e.LastError(); err != nil {
return 0, err
}
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
to := opts.To
@@ -461,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
@@ -482,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
@@ -492,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
if to.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
dst, netProto, err := e.checkV4MappedLocked(*to)
@@ -511,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, tcpip.ErrBroadcastDisabled
+ return 0, &tcpip.ErrBroadcastDisabled{}
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
so := e.SocketOptions()
if so.GetRecvError() {
so.QueueLocalErr(
- tcpip.ErrMessageTooLong,
+ &tcpip.ErrMessageTooLong{},
route.NetProto,
header.UDPMaximumPacketSize,
tcpip.FullAddress{
@@ -534,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
v,
)
}
- return 0, tcpip.ErrMessageTooLong
+ return 0, &tcpip.ErrMessageTooLong{}
}
ttl := e.ttl
@@ -584,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) {
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.MTUDiscoverOption:
// Return not supported if the value is not disabling path
// MTU discovery.
if v != tcpip.PMTUDiscoveryDont {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
case tcpip.MulticastTTLOption:
@@ -632,25 +633,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.rcvBufSizeMax = v
e.mu.Unlock()
return nil
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := e.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
- }
-
- if v < ss.Min {
- v = ss.Min
- }
- if v > ss.Max {
- v = ss.Max
- }
-
- e.mu.Lock()
- e.sndBufSizeMax = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -661,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
@@ -683,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
if nic != 0 {
if !e.stack.CheckNIC(nic) {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
} else {
nic = e.stack.CheckLocalAddress(0, netProto, addr)
if nic == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
if e.BindNICID != 0 && e.BindNICID != nic {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
e.multicastNICID = nic
@@ -701,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.AddMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
nicID := v.NIC
@@ -717,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
@@ -726,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
defer e.mu.Unlock()
if _, ok := e.multicastMemberships[memToInsert]; ok {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
@@ -737,7 +719,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
nicID := v.NIC
@@ -752,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
@@ -761,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
defer e.mu.Unlock()
if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
@@ -777,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.IPv4TOSOption:
e.mu.RLock()
@@ -811,12 +793,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSizeMax
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
v := e.rcvBufSizeMax
@@ -830,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
@@ -846,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
e.mu.Unlock()
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
return nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
Data: data,
@@ -903,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -912,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (e *endpoint) Disconnect() *tcpip.Error {
+func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -930,12 +906,12 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
- var err *tcpip.Error
+ var err tcpip.Error
id = stack.TransportEndpointID{
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
- id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
@@ -950,7 +926,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -961,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
if addr.Port == 0 {
// We don't support connecting to port zero.
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
e.mu.Lock()
@@ -981,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
if nicID != 0 && nicID != e.BindNICID {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nicID = e.BindNICID
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -1023,7 +999,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
oldPortFlags := e.boundPortFlags
- id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ id, btd, err := e.registerWithStack(netProtos, id)
if err != nil {
r.Release()
return err
@@ -1031,11 +1007,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
e.boundBindToDevice = btd
+ if e.route != nil {
+ // If the endpoint was already connected then make sure we release the
+ // previous route.
+ e.route.Release()
+ }
e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
@@ -1051,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
if state := e.EndpointState(); state != StateBound && state != StateConnected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
e.shutdownFlags |= flags
@@ -1084,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen is not supported by UDP, it just fails.
-func (*endpoint) Listen(int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
@@ -1104,7 +1085,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
@@ -1112,11 +1093,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
return id, bindToDevice, err
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -1140,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// A local unicast address was specified, verify that it's valid.
nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicID == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
@@ -1148,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
- id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ id, btd, err := e.registerWithStack(netProtos, id)
if err != nil {
return err
}
@@ -1170,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1186,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -1203,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.EndpointState() != StateConnected {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return tcpip.FullAddress{
@@ -1341,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
-func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
// Update last error first.
e.lastErrorMu.Lock()
e.lastError = err
@@ -1397,7 +1378,7 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt
default:
panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
}
- e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt)
+ e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 13b72dc88..21a6aa460 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
-// saveLastError is invoked by stateify.
-func (e *endpoint) saveLastError() string {
- if e.lastError == nil {
- return ""
- }
-
- return e.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (e *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- e.lastError = tcpip.StringToError(s)
-}
-
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
@@ -91,6 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
defer e.mu.Unlock()
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
@@ -113,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
netProto = header.IPv6ProtocolNumber
}
- var err *tcpip.Error
+ var err tcpip.Error
if state == StateConnected {
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
@@ -122,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
} else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
@@ -131,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 49e673d58..705ad1f64 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
}
// CreateEndpoint creates a connected UDP endpoint for the session request.
-func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
netHdr := r.pkt.Network()
route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
+ if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 91420edd3..427fdd0c9 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
@@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int {
// ParsePorts returns the source and destination ports stored in the given udp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
h := header.UDP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
@@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4e2123fe9..64c5298d3 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -353,7 +353,7 @@ func (c *testContext) cleanup() {
func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
c.t.Helper()
- var err *tcpip.Error
+ var err tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
c.t.Fatal("NewEndpoint failed: ", err)
@@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) {
testActions := []struct {
name string
setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
+ setBindToDeviceError tcpip.Error
getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
{"BindToExistent", nicIDPtr(321), nil, 321},
{"UnbindToDevice", nicIDPtr(0), nil, 0},
}
@@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
var buf bytes.Buffer
res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for data to become available.
select {
case <-ch:
@@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) {
t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
- if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
+ {
+ err := ep.Bind(addr)
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok {
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
+ }
}
}
@@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) {
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
- if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
+ {
+ err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok {
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
+ }
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
@@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) {
for _, tt := range []struct {
name string
handleLocal bool
- wantErr *tcpip.Error
+ wantErr tcpip.Error
wantInvalidSource uint64
}{
{"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
} {
t.Run(tt.name, func(t *testing.T) {
c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
@@ -959,15 +965,16 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
// testFailingWrite sends a packet of the given test flow into the UDP endpoint
// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- payload := buffer.View(newPayload())
- _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ var r bytes.Reader
+ r.Reset(newPayload())
+ _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
c.checkEndpointWriteStats(1, epstats, gotErr)
@@ -1007,8 +1014,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
}
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %s", err)
}
@@ -1089,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
testWrite(c, unicastV6)
// Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
const want = 1
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
@@ -1110,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
testWrite(c, unicastV4in6)
// Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+ testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
}
func TestV4WriteOnV6Only(t *testing.T) {
@@ -1120,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) {
c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
+ testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
@@ -1135,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+ testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
}
func TestV6WriteOnConnected(t *testing.T) {
@@ -1183,8 +1192,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
writeOpts := tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
}
@@ -1192,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
}
- if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
- c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ {
+ err := c.ep.LastError()
+ if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
}
})
}
@@ -2303,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) {
t.Fatalf("Shutdown failed: %s", err)
}
- testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+ testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
}
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
+ switch err.(type) {
case nil:
want.PacketsSent.IncrementBy(incr)
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
want.WriteErrors.WriteClosed.IncrementBy(incr)
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
want.SendErrors.NoRoute.IncrementBy(incr)
default:
want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
@@ -2327,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE
}
}
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
- case nil, tcpip.ErrWouldBlock:
- case tcpip.ErrClosedForReceive:
+ switch err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ case *tcpip.ErrClosedForReceive:
want.ReadErrors.ReadClosed.IncrementBy(incr)
default:
c.t.Errorf("Endpoint error missing stats update err %v", err)
@@ -2497,31 +2511,54 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
defer ep.Close()
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ var r bytes.Reader
+ data := []byte{1, 2, 3, 4}
to := tcpip.FullAddress{
Addr: test.remoteAddr,
Port: 80,
}
opts := tcpip.WriteOptions{To: &to}
- expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
+ if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
+ return nil
+ }
+ return &tcpip.ErrBroadcastDisabled{}
+ }
if !test.requiresBroadcastOpt {
expectedErrWithoutBcastOpt = nil
}
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
+ r.Reset(data)
+ {
+ n, err := ep.Write(&r, opts)
+ if expectedErrWithoutBcastOpt != nil {
+ if want := expectedErrWithoutBcastOpt(err); want != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
+ }
+ } else if err != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
+ }
}
ep.SocketOptions().SetBroadcast(true)
- if n, err := ep.Write(data, opts); err != nil {
+ r.Reset(data)
+ if n, err := ep.Write(&r, opts); err != nil {
t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
}
ep.SocketOptions().SetBroadcast(false)
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
+ r.Reset(data)
+ {
+ n, err := ep.Write(&r, opts)
+ if expectedErrWithoutBcastOpt != nil {
+ if want := expectedErrWithoutBcastOpt(err); want != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
+ }
+ } else if err != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
+ }
}
})
}
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 79db8895b..dc2571154 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er
// Reader returns an io.Reader that reads from s. Reads beyond the end of s
// return io.EOF. The preconditions that apply to s.CopyIn also apply to the
// returned io.Reader.Read.
-func (s IOSequence) Reader(ctx context.Context) io.Reader {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// Writer returns an io.Writer that writes to s. Writes beyond the end of s
// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also
// apply to the returned io.Writer.Write.
-func (s IOSequence) Writer(ctx context.Context) io.Writer {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when
// attempting to write beyond the end of the IOSequence.
var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence")
-type ioSequenceReadWriter struct {
+// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence.
+type IOSequenceReadWriter struct {
ctx context.Context
s IOSequence
}
// Read implements io.Reader.Read.
-func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) {
n, err := rw.s.CopyIn(rw.ctx, dst)
rw.s = rw.s.DropFirst(n)
if err == nil && rw.s.NumBytes() == 0 {
@@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
return n, err
}
+// Len implements tcpip.Payloader.
+func (rw *IOSequenceReadWriter) Len() int {
+ return int(rw.s.NumBytes())
+}
+
// Write implements io.Writer.Write.
-func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) {
n, err := rw.s.CopyOut(rw.ctx, src)
rw.s = rw.s.DropFirst(n)
if err == nil && n < len(src) {
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index a92ae046d..3bbf86534 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -52,7 +52,7 @@ func waitForProcessList(cont *Container, want []*control.Process) error {
cb := func() error {
got, err := cont.Processes()
if err != nil {
- err = fmt.Errorf("error getting process data from container: %v", err)
+ err = fmt.Errorf("error getting process data from container: %w", err)
return &backoff.PermanentError{Err: err}
}
if !procListsEqual(got, want) {
@@ -64,11 +64,30 @@ func waitForProcessList(cont *Container, want []*control.Process) error {
return testutil.Poll(cb, 30*time.Second)
}
+// waitForProcess waits for the given process to show up in the container.
+func waitForProcess(cont *Container, want *control.Process) error {
+ cb := func() error {
+ gots, err := cont.Processes()
+ if err != nil {
+ err = fmt.Errorf("error getting process data from container: %w", err)
+ return &backoff.PermanentError{Err: err}
+ }
+ for _, got := range gots {
+ if procEqual(got, want) {
+ return nil
+ }
+ }
+ return fmt.Errorf("container got process list: %s, want: %+v", procListToString(gots), want)
+ }
+ // Gives plenty of time as tests can run slow under --race.
+ return testutil.Poll(cb, 30*time.Second)
+}
+
func waitForProcessCount(cont *Container, want int) error {
cb := func() error {
pss, err := cont.Processes()
if err != nil {
- err = fmt.Errorf("error getting process data from container: %v", err)
+ err = fmt.Errorf("error getting process data from container: %w", err)
return &backoff.PermanentError{Err: err}
}
if got := len(pss); got != want {
@@ -101,28 +120,32 @@ func procListsEqual(gots, wants []*control.Process) bool {
return false
}
for i := range gots {
- got := gots[i]
- want := wants[i]
-
- if want.UID != math.MaxUint32 && want.UID != got.UID {
- return false
- }
- if want.PID != -1 && want.PID != got.PID {
- return false
- }
- if want.PPID != -1 && want.PPID != got.PPID {
- return false
- }
- if len(want.TTY) != 0 && want.TTY != got.TTY {
- return false
- }
- if len(want.Cmd) != 0 && want.Cmd != got.Cmd {
+ if !procEqual(gots[i], wants[i]) {
return false
}
}
return true
}
+func procEqual(got, want *control.Process) bool {
+ if want.UID != math.MaxUint32 && want.UID != got.UID {
+ return false
+ }
+ if want.PID != -1 && want.PID != got.PID {
+ return false
+ }
+ if want.PPID != -1 && want.PPID != got.PPID {
+ return false
+ }
+ if len(want.TTY) != 0 && want.TTY != got.TTY {
+ return false
+ }
+ if len(want.Cmd) != 0 && want.Cmd != got.Cmd {
+ return false
+ }
+ return true
+}
+
type processBuilder struct {
process control.Process
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 044eec6fe..bc802e075 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -1708,12 +1708,9 @@ func TestMultiContainerHomeEnvDir(t *testing.T) {
t.Errorf("wait on child container: %v", err)
}
- // Wait for the root container to run.
- expectedPL := []*control.Process{
- newProcessBuilder().Cmd("sh").Process(),
- newProcessBuilder().Cmd("sleep").Process(),
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ // Wait until after `env` has executed.
+ expectedProc := newProcessBuilder().Cmd("sleep").Process()
+ if err := waitForProcess(containers[0], expectedProc); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
}
@@ -1831,7 +1828,7 @@ func TestDuplicateEnvVariable(t *testing.T) {
cmd1 := fmt.Sprintf("env > %q; sleep 1000", files[0].Name())
cmd2 := fmt.Sprintf("env > %q", files[1].Name())
cmdExec := fmt.Sprintf("env > %q", files[2].Name())
- testSpecs, ids := createSpecs([]string{"/bin/bash", "-c", cmd1}, []string{"/bin/bash", "-c", cmd2})
+ testSpecs, ids := createSpecs([]string{"/bin/sh", "-c", cmd1}, []string{"/bin/sh", "-c", cmd2})
testSpecs[0].Process.Env = append(testSpecs[0].Process.Env, "VAR=foo", "VAR=bar")
testSpecs[1].Process.Env = append(testSpecs[1].Process.Env, "VAR=foo", "VAR=bar")
@@ -1841,12 +1838,9 @@ func TestDuplicateEnvVariable(t *testing.T) {
}
defer cleanup()
- // Wait for the `env` from the root container to finish.
- expectedPL := []*control.Process{
- newProcessBuilder().Cmd("bash").Process(),
- newProcessBuilder().Cmd("sleep").Process(),
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ // Wait until after `env` has executed.
+ expectedProc := newProcessBuilder().Cmd("sleep").Process()
+ if err := waitForProcess(containers[0], expectedProc); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
}
if ws, err := containers[1].Wait(); err != nil {
@@ -1856,8 +1850,8 @@ func TestDuplicateEnvVariable(t *testing.T) {
}
execArgs := &control.ExecArgs{
- Filename: "/bin/bash",
- Argv: []string{"/bin/bash", "-c", cmdExec},
+ Filename: "/bin/sh",
+ Argv: []string{"/bin/sh", "-c", cmdExec},
Envv: []string{"VAR=foo", "VAR=bar"},
}
if ws, err := containers[0].executeSync(execArgs); err != nil || ws.ExitStatus() != 0 {
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index c56e1d4d0..3280b74fe 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -12,7 +12,6 @@ go_library(
],
visibility = ["//runsc:__subpackages__"],
deps = [
- "//pkg/abi/linux",
"//pkg/cleanup",
"//pkg/fd",
"//pkg/log",
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index c3bba0973..cfa3796b1 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -31,7 +31,6 @@ import (
"strconv"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
@@ -49,13 +48,6 @@ const (
allowedOpenFlags = unix.O_TRUNC
)
-var (
- // Remember the process uid/gid to skip chown calls when file owner/group
- // doesn't need to be changed.
- processUID = p9.UID(os.Getuid())
- processGID = p9.GID(os.Getgid())
-)
-
// join is equivalent to path.Join() but skips path.Clean() which is expensive.
func join(parent, child string) string {
if child == "." || child == ".." {
@@ -374,7 +366,24 @@ func fstat(fd int) (unix.Stat_t, error) {
}
func fchown(fd int, uid p9.UID, gid p9.GID) error {
- return unix.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW)
+ return unix.Fchownat(fd, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW)
+}
+
+func setOwnerIfNeeded(fd int, uid p9.UID, gid p9.GID) (unix.Stat_t, error) {
+ stat, err := fstat(fd)
+ if err != nil {
+ return unix.Stat_t{}, err
+ }
+
+ // Change ownership if not set accordinly.
+ if uint32(uid) != stat.Uid || uint32(gid) != stat.Gid {
+ if err := fchown(fd, uid, gid); err != nil {
+ return unix.Stat_t{}, err
+ }
+ stat.Uid = uint32(uid)
+ stat.Gid = uint32(gid)
+ }
+ return stat, nil
}
// Open implements p9.File.
@@ -457,12 +466,7 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode,
})
defer cu.Clean()
- if uid != processUID || gid != processGID {
- if err := fchown(child.FD(), uid, gid); err != nil {
- return nil, nil, p9.QID{}, 0, extractErrno(err)
- }
- }
- stat, err := fstat(child.FD())
+ stat, err := setOwnerIfNeeded(child.FD(), uid, gid)
if err != nil {
return nil, nil, p9.QID{}, 0, extractErrno(err)
}
@@ -505,12 +509,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
}
defer f.Close()
- if uid != processUID || gid != processGID {
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
- }
- }
- stat, err := fstat(f.FD())
+ stat, err := setOwnerIfNeeded(f.FD(), uid, gid)
if err != nil {
return p9.QID{}, extractErrno(err)
}
@@ -734,15 +733,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
if valid.ATime || valid.MTime {
utimes := [2]unix.Timespec{
- {Sec: 0, Nsec: linux.UTIME_OMIT},
- {Sec: 0, Nsec: linux.UTIME_OMIT},
+ {Sec: 0, Nsec: unix.UTIME_OMIT},
+ {Sec: 0, Nsec: unix.UTIME_OMIT},
}
if valid.ATime {
if valid.ATimeNotSystemTime {
utimes[0].Sec = int64(attr.ATimeSeconds)
utimes[0].Nsec = int64(attr.ATimeNanoSeconds)
} else {
- utimes[0].Nsec = linux.UTIME_NOW
+ utimes[0].Nsec = unix.UTIME_NOW
}
}
if valid.MTime {
@@ -750,7 +749,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
utimes[1].Sec = int64(attr.MTimeSeconds)
utimes[1].Nsec = int64(attr.MTimeNanoSeconds)
} else {
- utimes[1].Nsec = linux.UTIME_NOW
+ utimes[1].Nsec = unix.UTIME_NOW
}
}
@@ -764,7 +763,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
}
defer unix.Close(parent)
- if tErr := utimensat(parent, path.Base(l.hostPath), utimes, linux.AT_SYMLINK_NOFOLLOW); tErr != nil {
+ if tErr := utimensat(parent, path.Base(l.hostPath), utimes, unix.AT_SYMLINK_NOFOLLOW); tErr != nil {
log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, tErr)
err = extractErrno(tErr)
}
@@ -779,15 +778,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
}
if valid.UID || valid.GID {
- uid := -1
+ uid := p9.NoUID
if valid.UID {
- uid = int(attr.UID)
+ uid = attr.UID
}
- gid := -1
+ gid := p9.NoGID
if valid.GID {
- gid = int(attr.GID)
+ gid = attr.GID
}
- if oErr := unix.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oErr != nil {
+ if oErr := fchown(f.FD(), uid, gid); oErr != nil {
log.Debugf("SetAttr fchownat failed %q, err: %v", l.hostPath, oErr)
err = extractErrno(oErr)
}
@@ -900,12 +899,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
}
defer f.Close()
- if uid != processUID || gid != processGID {
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
- }
- }
- stat, err := fstat(f.FD())
+ stat, err := setOwnerIfNeeded(f.FD(), uid, gid)
if err != nil {
return p9.QID{}, extractErrno(err)
}
@@ -921,7 +915,7 @@ func (l *localFile) Link(target p9.File, newName string) error {
}
targetFile := target.(*localFile)
- if err := unix.Linkat(targetFile.file.FD(), "", l.file.FD(), newName, linux.AT_EMPTY_PATH); err != nil {
+ if err := unix.Linkat(targetFile.file.FD(), "", l.file.FD(), newName, unix.AT_EMPTY_PATH); err != nil {
return extractErrno(err)
}
return nil
@@ -959,12 +953,7 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid
}
defer child.Close()
- if uid != processUID || gid != processGID {
- if err := fchown(child.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
- }
- }
- stat, err := fstat(child.FD())
+ stat, err := setOwnerIfNeeded(child.FD(), uid, gid)
if err != nil {
return p9.QID{}, extractErrno(err)
}
@@ -1113,7 +1102,8 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) {
// mappings, the app path may have fit in the sockaddr, but we can't
// fit f.path in our sockaddr. We'd need to redirect through a shorter
// path in order to actually connect to this socket.
- if len(l.hostPath) > linux.UnixPathMax {
+ const UNIX_PATH_MAX = 108 // defined in afunix.h
+ if len(l.hostPath) > UNIX_PATH_MAX {
return nil, unix.ECONNREFUSED
}
diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go
index c46958185..29ebf8500 100644
--- a/runsc/fsgofer/fsgofer_amd64_unsafe.go
+++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go
@@ -20,7 +20,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/syserr"
)
@@ -39,7 +38,7 @@ func statAt(dirFd int, name string) (unix.Stat_t, error) {
uintptr(dirFd),
uintptr(namePtr),
uintptr(statPtr),
- linux.AT_SYMLINK_NOFOLLOW,
+ unix.AT_SYMLINK_NOFOLLOW,
0,
0); errno != 0 {
diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go
index 491460718..9fd5d0871 100644
--- a/runsc/fsgofer/fsgofer_arm64_unsafe.go
+++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go
@@ -20,7 +20,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/syserr"
)
@@ -39,7 +38,7 @@ func statAt(dirFd int, name string) (unix.Stat_t, error) {
uintptr(dirFd),
uintptr(namePtr),
uintptr(statPtr),
- linux.AT_SYMLINK_NOFOLLOW,
+ unix.AT_SYMLINK_NOFOLLOW,
0,
0); errno != 0 {
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index c5daebe5e..99ea9bd32 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -32,6 +32,9 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
+// Nodoby is the standard UID/GID for the nobody user/group.
+const nobody = 65534
+
var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
var (
@@ -281,6 +284,92 @@ func TestCreate(t *testing.T) {
})
}
+func checkIDs(f p9.File, uid, gid int) error {
+ _, _, stat, err := f.GetAttr(p9.AttrMask{UID: true, GID: true})
+ if err != nil {
+ return fmt.Errorf("GetAttr() failed, err: %v", err)
+ }
+ if want := p9.UID(uid); stat.UID != want {
+ return fmt.Errorf("Wrong UID, want: %v, got: %v", want, stat.UID)
+ }
+ if want := p9.GID(gid); stat.GID != want {
+ return fmt.Errorf("Wrong GID, want: %v, got: %v", want, stat.GID)
+ }
+ return nil
+}
+
+// TestCreateSetGID checks files/dirs/symlinks are created with the proper
+// owner when the parent directory has setgid set,
+func TestCreateSetGID(t *testing.T) {
+ if !specutils.HasCapabilities(capability.CAP_CHOWN) {
+ t.Skipf("Test requires CAP_CHOWN")
+ }
+
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ // Change group and set setgid to the parent dir.
+ if err := unix.Chown(s.file.hostPath, os.Getuid(), nobody); err != nil {
+ t.Fatalf("Chown() failed: %v", err)
+ }
+ if err := unix.Chmod(s.file.hostPath, 02777); err != nil {
+ t.Fatalf("Chmod() failed: %v", err)
+ }
+
+ t.Run("create", func(t *testing.T) {
+ _, l, _, _, err := s.file.Create("test", p9.ReadOnly, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ t.Fatalf("WriteAt() failed: %v", err)
+ }
+ defer l.Close()
+ if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil {
+ t.Error(err)
+ }
+ })
+
+ t.Run("mkdir", func(t *testing.T) {
+ _, err := s.file.Mkdir("test-dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ if err != nil {
+ t.Fatalf("WriteAt() failed: %v", err)
+ }
+ _, l, err := s.file.Walk([]string{"test-dir"})
+ if err != nil {
+ t.Fatalf("Walk() failed: %v", err)
+ }
+ defer l.Close()
+ if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil {
+ t.Error(err)
+ }
+ })
+
+ t.Run("symlink", func(t *testing.T) {
+ if _, err := s.file.Symlink("/some/target", "symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("Symlink() failed: %v", err)
+ }
+ _, l, err := s.file.Walk([]string{"symlink"})
+ if err != nil {
+ t.Fatalf("Walk() failed, err: %v", err)
+ }
+ defer l.Close()
+ if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil {
+ t.Error(err)
+ }
+ })
+
+ t.Run("mknod", func(t *testing.T) {
+ if _, err := s.file.Mknod("nod", p9.ModeRegular|0777, 0, 0, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("Mknod() failed: %v", err)
+ }
+ _, l, err := s.file.Walk([]string{"nod"})
+ if err != nil {
+ t.Fatalf("Walk() failed, err: %v", err)
+ }
+ defer l.Close()
+ if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil {
+ t.Error(err)
+ }
+ })
+ })
+}
+
// TestReadWriteDup tests that a file opened in any mode can be dup'ed and
// reopened in any other mode.
func TestReadWriteDup(t *testing.T) {
@@ -458,7 +547,7 @@ func TestSetAttrTime(t *testing.T) {
}
func TestSetAttrOwner(t *testing.T) {
- if os.Getuid() != 0 {
+ if !specutils.HasCapabilities(capability.CAP_CHOWN) {
t.Skipf("SetAttr(owner) test requires CAP_CHOWN, running as %d", os.Getuid())
}
@@ -477,7 +566,7 @@ func TestSetAttrOwner(t *testing.T) {
}
func TestLink(t *testing.T) {
- if os.Getuid() != 0 {
+ if !specutils.HasCapabilities(capability.CAP_DAC_READ_SEARCH) {
t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid())
}
runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
@@ -995,7 +1084,6 @@ func BenchmarkCreateDiffOwner(b *testing.B) {
files := make([]p9.File, 0, 500)
fds := make([]*fd.FD, 0, 500)
gid := p9.GID(os.Getgid())
- const nobody = 65534
b.ResetTimer()
for i := 0; i < b.N; i++ {
diff --git a/runsc/mitigate/BUILD b/runsc/mitigate/BUILD
new file mode 100644
index 000000000..3b0342d18
--- /dev/null
+++ b/runsc/mitigate/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "mitigate",
+ srcs = [
+ "cpu.go",
+ "mitigate.go",
+ ],
+ deps = ["@in_gopkg_yaml_v2//:go_default_library"],
+)
+
+go_test(
+ name = "mitigate_test",
+ size = "small",
+ srcs = ["cpu_test.go"],
+ library = ":mitigate",
+ deps = ["@com_github_google_go_cmp//cmp:go_default_library"],
+)
diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go
new file mode 100644
index 000000000..113b98159
--- /dev/null
+++ b/runsc/mitigate/cpu.go
@@ -0,0 +1,235 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mitigate
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+const (
+ // constants of coomm
+ meltdown = "cpu_meltdown"
+ l1tf = "l1tf"
+ mds = "mds"
+ swapgs = "swapgs"
+ taa = "taa"
+)
+
+const (
+ processorKey = "processor"
+ vendorIDKey = "vendor_id"
+ cpuFamilyKey = "cpu family"
+ modelKey = "model"
+ coreIDKey = "core id"
+ bugsKey = "bugs"
+)
+
+// getCPUSet returns cpu structs from reading /proc/cpuinfo.
+func getCPUSet(data string) ([]*cpu, error) {
+ // Each processor entry should start with the
+ // processor key. Find the beginings of each.
+ r := buildRegex(processorKey, `\d+`)
+ indices := r.FindAllStringIndex(data, -1)
+ if len(indices) < 1 {
+ return nil, fmt.Errorf("no cpus found for: %s", data)
+ }
+
+ // Add the ending index for last entry.
+ indices = append(indices, []int{len(data), -1})
+
+ // Valid cpus are now defined by strings in between
+ // indexes (e.g. data[index[i], index[i+1]]).
+ // There should be len(indicies) - 1 CPUs
+ // since the last index is the end of the string.
+ var cpus = make([]*cpu, 0, len(indices)-1)
+ // Find each string that represents a CPU. These begin "processor".
+ for i := 1; i < len(indices); i++ {
+ start := indices[i-1][0]
+ end := indices[i][0]
+ // Parse the CPU entry, which should be between start/end.
+ c, err := getCPU(data[start:end])
+ if err != nil {
+ return nil, err
+ }
+ cpus = append(cpus, c)
+ }
+ return cpus, nil
+}
+
+// type cpu represents pertinent info about a cpu.
+type cpu struct {
+ processorNumber int64 // the processor number of this CPU.
+ vendorID string // the vendorID of CPU (e.g. AuthenticAMD).
+ cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake).
+ model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake).
+ coreID int64 // This CPU's core id to match Hyperthread Pairs
+ bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field.
+}
+
+// getCPU parses a CPU from a single cpu entry from /proc/cpuinfo.
+func getCPU(data string) (*cpu, error) {
+ processor, err := parseProcessor(data)
+ if err != nil {
+ return nil, err
+ }
+
+ vendorID, err := parseVendorID(data)
+ if err != nil {
+ return nil, err
+ }
+
+ cpuFamily, err := parseCPUFamily(data)
+ if err != nil {
+ return nil, err
+ }
+
+ model, err := parseModel(data)
+ if err != nil {
+ return nil, err
+ }
+
+ coreID, err := parseCoreID(data)
+ if err != nil {
+ return nil, err
+ }
+
+ bugs, err := parseBugs(data)
+ if err != nil {
+ return nil, err
+ }
+
+ return &cpu{
+ processorNumber: processor,
+ vendorID: vendorID,
+ cpuFamily: cpuFamily,
+ model: model,
+ coreID: coreID,
+ bugs: bugs,
+ }, nil
+}
+
+// List of pertinent side channel vulnerablilites.
+// For mds, see: https://www.kernel.org/doc/html/latest/admin-guide/hw-vuln/mds.html.
+var vulnerabilities = []string{
+ meltdown,
+ l1tf,
+ mds,
+ swapgs,
+ taa,
+}
+
+// isVulnerable checks if a CPU is vulnerable to pertinent bugs.
+func (c *cpu) isVulnerable() bool {
+ for _, bug := range vulnerabilities {
+ if _, ok := c.bugs[bug]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+// similarTo checks family/model/bugs fields for equality of two
+// processors.
+func (c *cpu) similarTo(other *cpu) bool {
+ if c.vendorID != other.vendorID {
+ return false
+ }
+
+ if other.cpuFamily != c.cpuFamily {
+ return false
+ }
+
+ if other.model != c.model {
+ return false
+ }
+
+ if len(other.bugs) != len(c.bugs) {
+ return false
+ }
+
+ for bug := range c.bugs {
+ if _, ok := other.bugs[bug]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+// parseProcessor grabs the processor field from /proc/cpuinfo output.
+func parseProcessor(data string) (int64, error) {
+ return parseIntegerResult(data, processorKey)
+}
+
+// parseVendorID grabs the vendor_id field from /proc/cpuinfo output.
+func parseVendorID(data string) (string, error) {
+ return parseRegex(data, vendorIDKey, `[\w\d]+`)
+}
+
+// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output.
+func parseCPUFamily(data string) (int64, error) {
+ return parseIntegerResult(data, cpuFamilyKey)
+}
+
+// parseModel grabs the model field from /proc/cpuinfo output.
+func parseModel(data string) (int64, error) {
+ return parseIntegerResult(data, modelKey)
+}
+
+// parseCoreID parses the core id field.
+func parseCoreID(data string) (int64, error) {
+ return parseIntegerResult(data, coreIDKey)
+}
+
+// parseBugs grabs the bugs field from /proc/cpuinfo output.
+func parseBugs(data string) (map[string]struct{}, error) {
+ result, err := parseRegex(data, bugsKey, `[\d\w\s]*`)
+ if err != nil {
+ return nil, err
+ }
+ bugs := strings.Split(result, " ")
+ ret := make(map[string]struct{}, len(bugs))
+ for _, bug := range bugs {
+ ret[bug] = struct{}{}
+ }
+ return ret, nil
+}
+
+// parseIntegerResult parses fields expecting an integer.
+func parseIntegerResult(data, key string) (int64, error) {
+ result, err := parseRegex(data, key, `\d+`)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.ParseInt(result, 0, 64)
+}
+
+// buildRegex builds a regex for parsing each CPU field.
+func buildRegex(key, match string) *regexp.Regexp {
+ reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key)
+ return regexp.MustCompile(reg)
+}
+
+// parseRegex parses data with key inserted into a standard regex template.
+func parseRegex(data, key, match string) (string, error) {
+ r := buildRegex(key, match)
+ matches := r.FindStringSubmatch(data)
+ if len(matches) < 2 {
+ return "", fmt.Errorf("failed to match key %s: %s", key, data)
+ }
+ return matches[1], nil
+}
diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go
new file mode 100644
index 000000000..77b714a02
--- /dev/null
+++ b/runsc/mitigate/cpu_test.go
@@ -0,0 +1,368 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mitigate
+
+import (
+ "io/ioutil"
+ "strings"
+ "testing"
+)
+
+// CPU info for a Intel CascadeLake processor. Both Skylake and CascadeLake have
+// the same family/model numbers, but with different bugs (e.g. skylake has
+// cpu_meltdown).
+var cascadeLake = &cpu{
+ vendorID: "GenuineIntel",
+ cpuFamily: 6,
+ model: 85,
+ bugs: map[string]struct{}{
+ "spectre_v1": struct{}{},
+ "spectre_v2": struct{}{},
+ "spec_store_bypass": struct{}{},
+ mds: struct{}{},
+ swapgs: struct{}{},
+ taa: struct{}{},
+ },
+}
+
+// TestGetCPU tests basic parsing of single CPU strings from reading
+// /proc/cpuinfo.
+func TestGetCPU(t *testing.T) {
+ data := `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+core id : 0
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit
+`
+ want := cpu{
+ processorNumber: 0,
+ vendorID: "GenuineIntel",
+ cpuFamily: 6,
+ model: 85,
+ coreID: 0,
+ bugs: map[string]struct{}{
+ "cpu_meltdown": struct{}{},
+ "spectre_v1": struct{}{},
+ "spectre_v2": struct{}{},
+ "spec_store_bypass": struct{}{},
+ "l1tf": struct{}{},
+ "mds": struct{}{},
+ "swapgs": struct{}{},
+ "taa": struct{}{},
+ "itlb_multihit": struct{}{},
+ },
+ }
+
+ got, err := getCPU(data)
+ if err != nil {
+ t.Fatalf("getCpu failed with error: %v", err)
+ }
+
+ if !want.similarTo(got) {
+ t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want)
+ }
+
+ if !got.isVulnerable() {
+ t.Fatalf("Failed: cpu should be vulnerable.")
+ }
+}
+
+func TestInvalid(t *testing.T) {
+ result, err := getCPUSet(`something not a processor`)
+ if err == nil {
+ t.Fatalf("getCPU set didn't return an error: %+v", result)
+ }
+
+ if !strings.Contains(err.Error(), "no cpus") {
+ t.Fatalf("Incorrect error returned: %v", err)
+ }
+}
+
+// TestCPUSet tests getting the right number of CPUs from
+// parsing full output of /proc/cpuinfo.
+func TestCPUSet(t *testing.T) {
+ data := `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:
+
+processor : 1
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 1
+initial apicid : 1
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:
+`
+ cpuSet, err := getCPUSet(data)
+ if err != nil {
+ t.Fatalf("getCPUSet failed: %v", err)
+ }
+
+ wantCPULen := 2
+ if len(cpuSet) != wantCPULen {
+ t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet))
+ }
+
+ wantCPU := cpu{
+ vendorID: "GenuineIntel",
+ cpuFamily: 6,
+ model: 63,
+ bugs: map[string]struct{}{
+ "cpu_meltdown": struct{}{},
+ "spectre_v1": struct{}{},
+ "spectre_v2": struct{}{},
+ "spec_store_bypass": struct{}{},
+ "l1tf": struct{}{},
+ "mds": struct{}{},
+ "swapgs": struct{}{},
+ },
+ }
+
+ for _, c := range cpuSet {
+ if !wantCPU.similarTo(c) {
+ t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU)
+ }
+ }
+}
+
+// TestReadFile is a smoke test for parsing methods.
+func TestReadFile(t *testing.T) {
+ data, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ t.Fatalf("Failed to read cpuinfo: %v", err)
+ }
+
+ set, err := getCPUSet(string(data))
+ if err != nil {
+ t.Fatalf("Failed to parse CPU data %v\n%s", err, data)
+ }
+
+ if len(set) < 1 {
+ t.Fatalf("Failed to parse any CPUs: %d", len(set))
+ }
+
+ for _, c := range set {
+ t.Logf("CPU: %+v: %t", c, c.isVulnerable())
+ }
+}
+
+// TestVulnerable tests if the isVulnerable method is correct
+// among known CPUs in GCP.
+func TestVulnerable(t *testing.T) {
+ const haswell = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 4
+core id : 0
+cpu cores : 2
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const skylake = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+model name : Intel(R) Xeon(R) CPU @ 2.00GHz
+stepping : 3
+microcode : 0x1
+cpu MHz : 2000.180
+cache size : 39424 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
+bogomips : 4000.36
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const cascade = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+model name : Intel(R) Xeon(R) CPU
+stepping : 7
+microcode : 0x1
+cpu MHz : 2800.198
+cache size : 33792 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2
+ ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu
+lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr
+efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r
+tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a
+rat avx512_vnni md_clear arch_capabilities
+bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa
+bogomips : 5600.39
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const amd = `processor : 0
+vendor_id : AuthenticAMD
+cpu family : 23
+model : 49
+model name : AMD EPYC 7B12
+stepping : 0
+microcode : 0x1000065
+cpu MHz : 2250.000
+cache size : 512 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
+bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+bogomips : 4500.00
+TLB size : 3072 4K pages
+clflush size : 64
+cache_alignment : 64
+address sizes : 48 bits physical, 48 bits virtual
+power management:`
+
+ for _, tc := range []struct {
+ name string
+ cpuString string
+ vulnerable bool
+ }{
+ {
+ name: "haswell",
+ cpuString: haswell,
+ vulnerable: true,
+ }, {
+ name: "skylake",
+ cpuString: skylake,
+ vulnerable: true,
+ }, {
+ name: "cascadeLake",
+ cpuString: cascade,
+ vulnerable: false,
+ }, {
+ name: "amd",
+ cpuString: amd,
+ vulnerable: false,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ set, err := getCPUSet(tc.cpuString)
+ if err != nil {
+ t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString)
+ }
+
+ if len(set) < 1 {
+ t.Fatalf("Returned empty cpu set: %v", set)
+ }
+
+ for _, c := range set {
+ got := func() bool {
+ if cascadeLake.similarTo(c) {
+ return false
+ }
+ return c.isVulnerable()
+ }()
+
+ if got != tc.vulnerable {
+ t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got)
+ }
+ }
+ })
+ }
+}
diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go
new file mode 100644
index 000000000..51d5449b6
--- /dev/null
+++ b/runsc/mitigate/mitigate.go
@@ -0,0 +1,20 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mitigate provides libraries for the mitigate command. The
+// mitigate command mitigates side channel attacks such as MDS. Mitigate
+// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online. In addition,
+// the mitigate also handles computing available CPU in kubernetes kube_config
+// files.
+package mitigate
diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go
index 63f692ca9..300c21ceb 100644
--- a/test/fsstress/fsstress_test.go
+++ b/test/fsstress/fsstress_test.go
@@ -41,7 +41,7 @@ func fsstress(t *testing.T, dir string) {
image = "basic/fsstress"
)
seed := strconv.FormatUint(uint64(rand.Uint32()), 10)
- args := []string{"-d", dir, "-n", operations, "-p", processes, "-seed", seed, "-X"}
+ args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"}
t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, ""))
out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...)
if err != nil {
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index 66453772a..ae4bba847 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -15,7 +15,9 @@ go_library(
],
visibility = ["//test/iptables:__subpackages__"],
deps = [
+ "//pkg/binary",
"//pkg/test/testutil",
+ "//pkg/usermem",
],
)
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index 37a1a6694..c47660026 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -51,6 +51,12 @@ func init() {
RegisterTestCase(FilterInputInvertDestination{})
RegisterTestCase(FilterInputSource{})
RegisterTestCase(FilterInputInvertSource{})
+ RegisterTestCase(FilterInputInterfaceAccept{})
+ RegisterTestCase(FilterInputInterfaceDrop{})
+ RegisterTestCase(FilterInputInterface{})
+ RegisterTestCase(FilterInputInterfaceBeginsWith{})
+ RegisterTestCase(FilterInputInterfaceInvertDrop{})
+ RegisterTestCase(FilterInputInterfaceInvertAccept{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -744,3 +750,195 @@ func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, i
func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
return sendUDPLoop(ctx, ip, acceptPort)
}
+
+// FilterInputInterfaceAccept tests that packets are accepted from interface
+// matching the iptables rule.
+type FilterInputInterfaceAccept struct{ localCase }
+
+var _ TestCase = FilterInputInterfaceAccept{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterfaceAccept) Name() string {
+ return "FilterInputInterfaceAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "ACCEPT"); err != nil {
+ return err
+ }
+ if err := listenUDP(ctx, acceptPort); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInterfaceDrop tests that packets are dropped from interface
+// matching the iptables rule.
+type FilterInputInterfaceDrop struct{ localCase }
+
+var _ TestCase = FilterInputInterfaceDrop{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterfaceDrop) Name() string {
+ return "FilterInputInterfaceDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "DROP"); err != nil {
+ return err
+ }
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, acceptPort); err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ return fmt.Errorf("error reading: %w", err)
+ }
+ return fmt.Errorf("packets should have been dropped, but got a packet")
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInterface tests that packets are not dropped from interface which
+// is not matching the interface name in the iptables rule.
+type FilterInputInterface struct{ localCase }
+
+var _ TestCase = FilterInputInterface{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterface) Name() string {
+ return "FilterInputInterface"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+ if err := listenUDP(ctx, acceptPort); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err)
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInterfaceBeginsWith tests that packets are dropped from an
+// interface which begins with the given interface name.
+type FilterInputInterfaceBeginsWith struct{ localCase }
+
+var _ TestCase = FilterInputInterfaceBeginsWith{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterfaceBeginsWith) Name() string {
+ return "FilterInputInterfaceBeginsWith"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "e+", "-j", "DROP"); err != nil {
+ return err
+ }
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, acceptPort); err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ return fmt.Errorf("error reading: %w", err)
+ }
+ return fmt.Errorf("packets should have been dropped, but got a packet")
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInterfaceInvertDrop tests that we selectively drop packets from
+// interface not matching the interface name.
+type FilterInputInterfaceInvertDrop struct{ baseCase }
+
+var _ TestCase = FilterInputInterfaceInvertDrop{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterfaceInvertDrop) Name() string {
+ return "FilterInputInterfaceInvertDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ return fmt.Errorf("error reading: %w", err)
+ }
+ return fmt.Errorf("connection on port %d should not be accepted, but was accepted", acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err != nil {
+ var operr *net.OpError
+ if errors.As(err, &operr) && operr.Timeout() {
+ return nil
+ }
+ return fmt.Errorf("error connecting: %w", err)
+ }
+ return fmt.Errorf("connection destined to port %d should not be accepted, but was accepted", acceptPort)
+}
+
+// FilterInputInterfaceInvertAccept tests that we can selectively accept packets
+// not matching the specific incoming interface.
+type FilterInputInterfaceInvertAccept struct{ baseCase }
+
+var _ TestCase = FilterInputInterfaceInvertAccept{}
+
+// Name implements TestCase.Name.
+func (FilterInputInterfaceInvertAccept) Name() string {
+ return "FilterInputInterfaceInvertAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 4733146c0..ef92e3fff 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -392,6 +392,30 @@ func TestInputInvertSource(t *testing.T) {
singleTest(t, FilterInputInvertSource{})
}
+func TestInputInterfaceAccept(t *testing.T) {
+ singleTest(t, FilterInputInterfaceAccept{})
+}
+
+func TestInputInterfaceDrop(t *testing.T) {
+ singleTest(t, FilterInputInterfaceDrop{})
+}
+
+func TestInputInterface(t *testing.T) {
+ singleTest(t, FilterInputInterface{})
+}
+
+func TestInputInterfaceBeginsWith(t *testing.T) {
+ singleTest(t, FilterInputInterfaceBeginsWith{})
+}
+
+func TestInputInterfaceInvertDrop(t *testing.T) {
+ singleTest(t, FilterInputInterfaceInvertDrop{})
+}
+
+func TestInputInterfaceInvertAccept(t *testing.T) {
+ singleTest(t, FilterInputInterfaceInvertAccept{})
+}
+
func TestFilterAddrs(t *testing.T) {
tcs := []struct {
ipv6 bool
@@ -424,3 +448,11 @@ func TestNATPreOriginalDst(t *testing.T) {
func TestNATOutOriginalDst(t *testing.T) {
singleTest(t, NATOutOriginalDst{})
}
+
+func TestNATPreRECVORIGDSTADDR(t *testing.T) {
+ singleTest(t, NATPreRECVORIGDSTADDR{})
+}
+
+func TestNATOutRECVORIGDSTADDR(t *testing.T) {
+ singleTest(t, NATOutRECVORIGDSTADDR{})
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index a6ec5cca3..4cd770a65 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -171,7 +171,7 @@ func connectTCP(ctx context.Context, ip net.IP, port int) error {
return err
}
if err := testutil.PollContext(ctx, callback); err != nil {
- return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err)
+ return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %w", port, err)
}
return nil
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 495241482..c3874240f 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -20,6 +20,9 @@ import (
"fmt"
"net"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const redirectPort = 42
@@ -43,6 +46,8 @@ func init() {
RegisterTestCase(NATLoopbackSkipsPrerouting{})
RegisterTestCase(NATPreOriginalDst{})
RegisterTestCase(NATOutOriginalDst{})
+ RegisterTestCase(NATPreRECVORIGDSTADDR{})
+ RegisterTestCase(NATOutRECVORIGDSTADDR{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -538,9 +543,9 @@ func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool)
}
func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error {
- // The net package doesn't give guarantee access to the connection's
+ // The net package doesn't give guaranteed access to the connection's
// underlying FD, and thus we cannot call getsockopt. We have to use
- // traditional syscalls for SO_ORIGINAL_DST.
+ // traditional syscalls.
// Create the listening socket, bind, listen, and accept.
family := syscall.AF_INET
@@ -609,36 +614,14 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
if err != nil {
return err
}
- // The original destination could be any of our IPs.
- for _, dst := range originalDsts {
- want := syscall.RawSockaddrInet6{
- Family: syscall.AF_INET6,
- Port: htons(dropPort),
- }
- copy(want.Addr[:], dst.To16())
- if got == want {
- return nil
- }
- }
- return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ return addrMatches6(got, originalDsts, dropPort)
}
got, err := originalDestination4(connFD)
if err != nil {
return err
}
- // The original destination could be any of our IPs.
- for _, dst := range originalDsts {
- want := syscall.RawSockaddrInet4{
- Family: syscall.AF_INET,
- Port: htons(dropPort),
- }
- copy(want.Addr[:], dst.To4())
- if got == want {
- return nil
- }
- }
- return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ return addrMatches4(got, originalDsts, dropPort)
}
// loopbackTests runs an iptables rule and ensures that packets sent to
@@ -662,3 +645,233 @@ func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) e
return err
}
}
+
+// NATPreRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT
+// address on the PREROUTING chain.
+type NATPreRECVORIGDSTADDR struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreRECVORIGDSTADDR) Name() string {
+ return "NATPreRECVORIGDSTADDR"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := recvWithRECVORIGDSTADDR(ctx, ipv6, nil, redirectPort); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATOutRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT
+// address on the OUTPUT chain.
+type NATOutRECVORIGDSTADDR struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATOutRECVORIGDSTADDR) Name() string {
+ return "NATOutRECVORIGDSTADDR"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ sendCh := make(chan error)
+ go func() {
+ // Packets will be sent to a non-container IP and redirected
+ // back to the container.
+ sendCh <- sendUDPLoop(ctx, ip, acceptPort)
+ }()
+
+ expectedIP := &net.IP{127, 0, 0, 1}
+ if ipv6 {
+ expectedIP = &net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ }
+ if err := recvWithRECVORIGDSTADDR(ctx, ipv6, expectedIP, redirectPort); err != nil {
+ return err
+ }
+
+ select {
+ case err := <-sendCh:
+ return err
+ default:
+ return nil
+ }
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP, port uint16) error {
+ // The net package doesn't give guaranteed access to a connection's
+ // underlying FD, and thus we cannot call getsockopt. We have to use
+ // traditional syscalls for IP_RECVORIGDSTADDR.
+
+ // Create the listening socket.
+ var (
+ family = syscall.AF_INET
+ level = syscall.SOL_IP
+ option = syscall.IP_RECVORIGDSTADDR
+ bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{
+ Port: int(port),
+ Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
+ }
+ )
+ if ipv6 {
+ family = syscall.AF_INET6
+ level = syscall.SOL_IPV6
+ option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package.
+ bindAddr = &syscall.SockaddrInet6{
+ Port: int(port),
+ Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
+ }
+ }
+ sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err)
+ }
+ defer syscall.Close(sockfd)
+
+ if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err)
+ }
+
+ // Enable IP_RECVORIGDSTADDR.
+ if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil {
+ return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err)
+ }
+
+ addrCh := make(chan interface{})
+ errCh := make(chan error)
+ go func() {
+ var addr interface{}
+ var err error
+ if ipv6 {
+ addr, err = recvOrigDstAddr6(sockfd)
+ } else {
+ addr, err = recvOrigDstAddr4(sockfd)
+ }
+ if err != nil {
+ errCh <- err
+ } else {
+ addrCh <- addr
+ }
+ }()
+
+ // Wait to receive a packet.
+ var addr interface{}
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errCh:
+ return err
+ case addr = <-addrCh:
+ }
+
+ // Get a list of local IPs to verify that the packet now appears to have
+ // been sent to us.
+ var localAddrs []net.IP
+ if expectedDst != nil {
+ localAddrs = []net.IP{*expectedDst}
+ } else {
+ localAddrs, err = getInterfaceAddrs(ipv6)
+ if err != nil {
+ return fmt.Errorf("failed to get local interfaces: %w", err)
+ }
+ }
+
+ // Verify that the address has the post-NAT port and address.
+ if ipv6 {
+ return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort)
+ }
+ return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort)
+}
+
+func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) {
+ buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4)
+ if err != nil {
+ return syscall.RawSockaddrInet4{}, err
+ }
+ var addr syscall.RawSockaddrInet4
+ binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ return addr, nil
+}
+
+func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) {
+ buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6)
+ if err != nil {
+ return syscall.RawSockaddrInet6{}, err
+ }
+ var addr syscall.RawSockaddrInet6
+ binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ return addr, nil
+}
+
+func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) {
+ buf := make([]byte, 64)
+ oob := make([]byte, syscall.CmsgSpace(addrSize))
+ for {
+ _, oobn, _, _, err := syscall.Recvmsg(
+ sockfd,
+ buf, // Message buffer.
+ oob, // Out-of-band buffer.
+ 0) // Flags.
+ if errors.Is(err, syscall.EINTR) {
+ continue
+ }
+ if err != nil {
+ return nil, fmt.Errorf("failed when calling Recvmsg: %w", err)
+ }
+ oob = oob[:oobn]
+
+ // Parse out the control message.
+ msgs, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse control message: %w", err)
+ }
+ return msgs[0].Data, nil
+ }
+}
+
+func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error {
+ for _, wantAddr := range wantAddrs {
+ want := syscall.RawSockaddrInet4{
+ Family: syscall.AF_INET,
+ Port: htons(port),
+ }
+ copy(want.Addr[:], wantAddr.To4())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
+}
+
+func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error {
+ for _, wantAddr := range wantAddrs {
+ want := syscall.RawSockaddrInet6{
+ Family: syscall.AF_INET6,
+ Port: htons(port),
+ }
+ copy(want.Addr[:], wantAddr.To16())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
+}
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
index ccf1c735f..0be14ca3e 100644
--- a/test/packetimpact/dut/BUILD
+++ b/test/packetimpact/dut/BUILD
@@ -14,6 +14,7 @@ cc_binary(
grpcpp,
"//test/packetimpact/proto:posix_server_cc_grpc_proto",
"//test/packetimpact/proto:posix_server_cc_proto",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -24,5 +25,6 @@ cc_binary(
grpcpp,
"//test/packetimpact/proto:posix_server_cc_grpc_proto",
"//test/packetimpact/proto:posix_server_cc_proto",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index 4de8540f6..eba21df12 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -16,6 +16,7 @@
#include <getopt.h>
#include <netdb.h>
#include <netinet/in.h>
+#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -30,6 +31,7 @@
#include "include/grpcpp/security/server_credentials.h"
#include "include/grpcpp/server_builder.h"
#include "include/grpcpp/server_context.h"
+#include "absl/strings/str_format.h"
#include "test/packetimpact/proto/posix_server.grpc.pb.h"
#include "test/packetimpact/proto/posix_server.pb.h"
@@ -256,6 +258,44 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status Poll(::grpc::ServerContext *context,
+ const ::posix_server::PollRequest *request,
+ ::posix_server::PollResponse *response) override {
+ std::vector<struct pollfd> pfds;
+ pfds.reserve(request->pfds_size());
+ for (const auto &pfd : request->pfds()) {
+ pfds.push_back({
+ .fd = pfd.fd(),
+ .events = static_cast<short>(pfd.events()),
+ });
+ }
+ int ret = ::poll(pfds.data(), pfds.size(), request->timeout_millis());
+
+ response->set_ret(ret);
+ if (ret < 0) {
+ response->set_errno_(errno);
+ } else {
+ // Only pollfds that have non-empty revents are returned, the client can't
+ // rely on indexes of the request array.
+ for (const auto &pfd : pfds) {
+ if (pfd.revents) {
+ auto *proto_pfd = response->add_pfds();
+ proto_pfd->set_fd(pfd.fd);
+ proto_pfd->set_events(pfd.revents);
+ }
+ }
+ if (int ready = response->pfds_size(); ret != ready) {
+ return ::grpc::Status(
+ ::grpc::StatusCode::INTERNAL,
+ absl::StrFormat(
+ "poll's return value(%d) doesn't match the number of "
+ "file descriptors that are actually ready(%d)",
+ ret, ready));
+ }
+ }
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status Send(::grpc::ServerContext *context,
const ::posix_server::SendRequest *request,
::posix_server::SendResponse *response) override {
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
index f32ed54ef..b4c68764a 100644
--- a/test/packetimpact/proto/posix_server.proto
+++ b/test/packetimpact/proto/posix_server.proto
@@ -142,6 +142,25 @@ message ListenResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
+// The events field is overloaded: when used for request, it is copied into the
+// events field of posix struct pollfd; when used for response, it is filled by
+// the revents field from the posix struct pollfd.
+message PollFd {
+ int32 fd = 1;
+ uint32 events = 2;
+}
+
+message PollRequest {
+ repeated PollFd pfds = 1;
+ int32 timeout_millis = 2;
+}
+
+message PollResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ repeated PollFd pfds = 3;
+}
+
message SendRequest {
int32 sockfd = 1;
bytes buf = 2;
@@ -226,6 +245,10 @@ service Posix {
rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse);
// Call listen() on the DUT.
rpc Listen(ListenRequest) returns (ListenResponse);
+ // Call poll() on the DUT. Only pollfds that have non-empty revents are
+ // returned, the only way to tie the response back to the original request
+ // is using the fd number.
+ rpc Poll(PollRequest) returns (PollResponse);
// Call send() on the DUT.
rpc Send(SendRequest) returns (SendResponse);
// Call sendto() on the DUT.
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index c6c95546a..a7c46781f 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -175,9 +175,6 @@ ALL_TESTS = [
name = "udp_discard_mcast_source_addr",
),
PacketimpactTestInfo(
- name = "udp_recv_mcast_bcast",
- ),
- PacketimpactTestInfo(
name = "udp_any_addr_recv_unicast",
),
PacketimpactTestInfo(
@@ -277,6 +274,13 @@ ALL_TESTS = [
PacketimpactTestInfo(
name = "tcp_rcv_buf_space",
),
+ PacketimpactTestInfo(
+ name = "tcp_rack",
+ expect_netstack_failure = True,
+ ),
+ PacketimpactTestInfo(
+ name = "tcp_info",
+ ),
]
def validate_all_tests():
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 576577310..1453ac232 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -1008,6 +1008,13 @@ func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
return sa
}
+// SrcPort returns the source port of this connection.
+func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 {
+ t.Helper()
+
+ return *conn.udpState(t).out.SrcPort
+}
+
// Send sends a packet with reasonable defaults, potentially overriding the UDP
// layer and adding additionLayers.
func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
@@ -1024,6 +1031,11 @@ func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...
(*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ (*Connection)(conn).send(t, overrideLayers, additionalLayers...)
+}
+
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
@@ -1053,6 +1065,14 @@ func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout
return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ return (*Connection)(conn).ExpectFrame(t, frame, timeout)
+}
+
// Close frees associated resources held by the UDPIPv4 connection.
func (conn *UDPIPv4) Close(t *testing.T) {
t.Helper()
@@ -1136,6 +1156,13 @@ func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6
return sa
}
+// SrcPort returns the source port of this connection.
+func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 {
+ t.Helper()
+
+ return *conn.udpState(t).out.SrcPort
+}
+
// Send sends a packet with reasonable defaults, potentially overriding the UDP
// layer and adding additionLayers.
func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
@@ -1152,6 +1179,11 @@ func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers .
(*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ (*Connection)(conn).send(t, overrideLayers, additionalLayers...)
+}
+
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
@@ -1181,6 +1213,14 @@ func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout
return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv6) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ return (*Connection)(conn).ExpectFrame(t, frame, timeout)
+}
+
// Close frees associated resources held by the UDPIPv6 connection.
func (conn *UDPIPv6) Close(t *testing.T) {
t.Helper()
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 66a0255b8..aedcf6013 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -486,6 +486,56 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backl
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
+// Poll calls poll on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over error handling is needed, use PollWithErrno.
+// Only pollfds with non-empty revents are returned, the only way to tie the
+// response back to the original request is using the fd number.
+func (dut *DUT) Poll(t *testing.T, pfds []unix.PollFd, timeout time.Duration) []unix.PollFd {
+ t.Helper()
+
+ ctx := context.Background()
+ var cancel context.CancelFunc
+ if timeout >= 0 {
+ ctx, cancel = context.WithTimeout(ctx, timeout+RPCTimeout)
+ defer cancel()
+ }
+ ret, result, err := dut.PollWithErrno(ctx, t, pfds, timeout)
+ if ret < 0 {
+ t.Fatalf("failed to poll: %s", err)
+ }
+ return result
+}
+
+// PollWithErrno calls poll on the DUT.
+func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.PollFd, timeout time.Duration) (int32, []unix.PollFd, error) {
+ t.Helper()
+
+ req := pb.PollRequest{
+ TimeoutMillis: int32(timeout.Milliseconds()),
+ }
+ for _, pfd := range pfds {
+ req.Pfds = append(req.Pfds, &pb.PollFd{
+ Fd: pfd.Fd,
+ Events: uint32(pfd.Events),
+ })
+ }
+ resp, err := dut.posixServer.Poll(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Poll: %s", err)
+ }
+ if ret, npfds := resp.GetRet(), len(resp.GetPfds()); ret >= 0 && int(ret) != npfds {
+ t.Fatalf("nonsensical poll response: ret(%d) != len(pfds)(%d)", ret, npfds)
+ }
+ var result []unix.PollFd
+ for _, protoPfd := range resp.GetPfds() {
+ result = append(result, unix.PollFd{
+ Fd: protoPfd.GetFd(),
+ Revents: int16(protoPfd.GetEvents()),
+ })
+ }
+ return resp.GetRet(), result, syscall.Errno(resp.GetErrno_())
+}
+
// Send calls send on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendWithErrno.
@@ -544,7 +594,7 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32,
}
resp, err := dut.posixServer.SendTo(ctx, &req)
if err != nil {
- t.Fatalf("faled to call SendTo: %s", err)
+ t.Fatalf("failed to call SendTo: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index b1b3c578b..baa3ae5e9 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -38,18 +38,6 @@ packetimpact_testbench(
)
packetimpact_testbench(
- name = "udp_recv_mcast_bcast",
- srcs = ["udp_recv_mcast_bcast_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//test/packetimpact/testbench",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-packetimpact_testbench(
name = "udp_any_addr_recv_unicast",
srcs = ["udp_any_addr_recv_unicast_test.go"],
deps = [
@@ -340,6 +328,8 @@ packetimpact_testbench(
name = "udp_send_recv_dgram",
srcs = ["udp_send_recv_dgram_test.go"],
deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
"//test/packetimpact/testbench",
"@com_github_google_go_cmp//cmp:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
@@ -376,6 +366,33 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_rack",
+ srcs = ["tcp_rack_test.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//pkg/usermem",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
+ name = "tcp_info",
+ srcs = ["tcp_info_test.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/tcpip/header",
+ "//pkg/usermem",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
index d2203082d..ee050e2c6 100644
--- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
@@ -45,8 +45,6 @@ func TestIPv4FragmentReassembly(t *testing.T) {
ipPayloadLen int
fragments []fragmentInfo
expectReply bool
- skip bool
- skipReason string
}{
{
description: "basic reassembly",
@@ -78,8 +76,6 @@ func TestIPv4FragmentReassembly(t *testing.T) {
{offset: 2000, size: 1000, id: 7, more: 0},
},
expectReply: true,
- skip: true,
- skipReason: "gvisor.dev/issues/4971",
},
{
description: "fragment subset",
@@ -91,8 +87,6 @@ func TestIPv4FragmentReassembly(t *testing.T) {
{offset: 2000, size: 1000, id: 8, more: 0},
},
expectReply: true,
- skip: true,
- skipReason: "gvisor.dev/issues/4971",
},
{
description: "fragment overlap",
@@ -104,16 +98,10 @@ func TestIPv4FragmentReassembly(t *testing.T) {
{offset: 2000, size: 1000, id: 9, more: 0},
},
expectReply: false,
- skip: true,
- skipReason: "gvisor.dev/issues/4971",
},
}
for _, test := range tests {
- if test.skip {
- t.Skip("%s test skipped: %s", test.description, test.skipReason)
- continue
- }
t.Run(test.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
conn := dut.Net.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{})
diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go
new file mode 100644
index 000000000..b66e8f609
--- /dev/null
+++ b/test/packetimpact/tests/tcp_info_test.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_info_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+func TestTCPInfo(t *testing.T) {
+ // Create a socket, listen, TCP connect, and accept.
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ // Send and receive sample data.
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+ dut.Send(t, acceptFD, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ info := linux.TCPInfo{}
+ infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo))
+ binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+
+ rtt := time.Duration(info.RTT) * time.Microsecond
+ rttvar := time.Duration(info.RTTVar) * time.Microsecond
+ rto := time.Duration(info.RTO) * time.Microsecond
+ if rtt == 0 || rttvar == 0 || rto == 0 {
+ t.Errorf("expected rtt(%v), rttvar(%v) and rto(%v) to be greater than zero", rtt, rttvar, rto)
+ }
+ if info.ReordSeen != 0 {
+ t.Errorf("expected the connection to not have any reordering, got: %v want: 0", info.ReordSeen)
+ }
+ if info.SndCwnd == 0 {
+ t.Errorf("expected send congestion window to be greater than zero")
+ }
+ if info.CaState != linux.TCP_CA_Open {
+ t.Errorf("expected the connection to be in open state, got: %v want: %v", info.CaState, linux.TCP_CA_Open)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Check the congestion control state and send congestion window after
+ // retransmission timeout.
+ seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
+ dut.Send(t, acceptFD, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+
+ // Expect retransmission of the packet within 1.5*RTO.
+ timeout := time.Duration(float64(info.RTO)*1.5) * time.Microsecond
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+
+ info = linux.TCPInfo{}
+ infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo))
+ binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ if info.CaState != linux.TCP_CA_Loss {
+ t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss)
+ }
+ if info.SndCwnd != 1 {
+ t.Errorf("expected send congestion window to be 1, got: %v %v", info.SndCwnd)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index f0af5352d..c874a8912 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -34,6 +34,21 @@ func TestTcpNoAcceptCloseReset(t *testing.T) {
conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
conn.Connect(t)
defer conn.Close(t)
+ // We need to wait for POLLIN event on listenFd to know the connection is
+ // established. Otherwise there could be a race when we issue the Close
+ // command prior to the DUT receiving the last ack of the handshake and
+ // it will only respond RST instead of RST+ACK.
+ timeout := time.Second
+ pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, timeout)
+ if n := len(pfds); n != 1 {
+ t.Fatalf("poll returned %d ready file descriptors, expected 1", n)
+ }
+ if readyFd := pfds[0].Fd; readyFd != listenFd {
+ t.Fatalf("poll returned an fd %d that was not requested (%d)", readyFd, listenFd)
+ }
+ if got, want := pfds[0].Revents, int16(unix.POLLIN); got&want == 0 {
+ t.Fatalf("poll returned no events in our interest, got: %#b, want: %#b", got, want)
+ }
dut.Close(t, listenFd)
if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
t.Fatalf("expected a RST-ACK packet but got none: %s", err)
diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go
new file mode 100644
index 000000000..0a2381c97
--- /dev/null
+++ b/test/packetimpact/tests/tcp_rack_test.go
@@ -0,0 +1,221 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_rack_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+const (
+ // payloadSize is the size used to send packets.
+ payloadSize = header.TCPDefaultMSS
+
+ // simulatedRTT is the time delay between packets sent and acked to
+ // increase the RTT.
+ simulatedRTT = 30 * time.Millisecond
+
+ // numPktsForRTT is the number of packets sent and acked to establish
+ // RTT.
+ numPktsForRTT = 10
+)
+
+func createSACKConnection(t *testing.T) (testbench.DUT, testbench.TCPIPv4, int32, int32) {
+ dut := testbench.NewDUT(t)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+
+ // Enable SACK.
+ opts := make([]byte, 40)
+ optsOff := 0
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeSACKPermittedOption(opts[optsOff:])
+
+ conn.ConnectWithOptions(t, opts[:optsOff])
+ acceptFd, _ := dut.Accept(t, listenFd)
+ return dut, conn, acceptFd, listenFd
+}
+
+func closeSACKConnection(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, acceptFd, listenFd int32) {
+ dut.Close(t, acceptFd)
+ dut.Close(t, listenFd)
+ conn.Close(t)
+}
+
+func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto time.Duration) {
+ info := linux.TCPInfo{}
+ ret := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo))
+ binary.Unmarshal(ret, usermem.ByteOrder, &info)
+ return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond
+}
+
+func sendAndReceive(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, numPkts int, acceptFd int32, sendACK bool) time.Time {
+ seqNum1 := *conn.RemoteSeqNum(t)
+ payload := make([]byte, payloadSize)
+ var lastSent time.Time
+ for i, sn := 0, seqNum1; i < numPkts; i++ {
+ lastSent = time.Now()
+ dut.Send(t, acceptFd, payload, 0)
+ gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, time.Second)
+ if err != nil {
+ t.Fatalf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Fatalf("#%d: expected a packet within a second but got none", i+1)
+ }
+ sn.UpdateForward(seqnum.Size(payloadSize))
+
+ if sendACK {
+ time.Sleep(simulatedRTT)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))})
+ }
+ }
+ return lastSent
+}
+
+// TestRACKTLPAllPacketsLost tests TLP when an entire flight of data is lost.
+func TestRACKTLPAllPacketsLost(t *testing.T) {
+ dut, conn, acceptFd, listenFd := createSACKConnection(t)
+ seqNum1 := *conn.RemoteSeqNum(t)
+
+ // Send ACK for data packets to establish RTT.
+ sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */)
+ seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize))
+
+ // We are not sending ACK for these packets.
+ const numPkts = 5
+ lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */)
+
+ // Probe Timeout (PTO) should be two times RTT. Check that the last
+ // packet is retransmitted after probe timeout.
+ rtt, _ := getRTTAndRTO(t, dut, acceptFd)
+ pto := rtt * 2
+ // We expect the 5th packet (the last unacknowledged packet) to be
+ // retransmitted.
+ tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize))
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s %v %v", err, rtt, pto)
+ }
+ diff := time.Now().Sub(lastSent)
+ if diff < pto {
+ t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto)
+ }
+ closeSACKConnection(t, dut, conn, acceptFd, listenFd)
+}
+
+// TestRACKTLPLost tests TLP when there are tail losses.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.4
+func TestRACKTLPLost(t *testing.T) {
+ dut, conn, acceptFd, listenFd := createSACKConnection(t)
+ seqNum1 := *conn.RemoteSeqNum(t)
+
+ // Send ACK for data packets to establish RTT.
+ sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */)
+ seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize))
+
+ // We are not sending ACK for these packets.
+ const numPkts = 10
+ lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */)
+
+ // Cumulative ACK for #[1-5] packets.
+ ackNum := seqNum1.Add(seqnum.Size(6 * payloadSize))
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))})
+
+ // Probe Timeout (PTO) should be two times RTT. Check that the last
+ // packet is retransmitted after probe timeout.
+ rtt, _ := getRTTAndRTO(t, dut, acceptFd)
+ pto := rtt * 2
+ // We expect the 10th packet (the last unacknowledged packet) to be
+ // retransmitted.
+ tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize))
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ diff := time.Now().Sub(lastSent)
+ if diff < pto {
+ t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto)
+ }
+ closeSACKConnection(t, dut, conn, acceptFd, listenFd)
+}
+
+// TestRACKTLPWithSACK tests TLP by acknowledging out of order packets.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-8.1
+func TestRACKTLPWithSACK(t *testing.T) {
+ dut, conn, acceptFd, listenFd := createSACKConnection(t)
+ seqNum1 := *conn.RemoteSeqNum(t)
+
+ // Send ACK for data packets to establish RTT.
+ sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */)
+ seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize))
+
+ // We are not sending ACK for these packets.
+ const numPkts = 3
+ lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */)
+
+ // SACK for #2 packet.
+ sackBlock := make([]byte, 40)
+ start := seqNum1.Add(seqnum.Size(payloadSize))
+ end := start.Add(seqnum.Size(payloadSize))
+ sbOff := 0
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
+ start, end,
+ }}, sackBlock[sbOff:])
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+
+ // RACK marks #1 packet as lost and retransmits it.
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // ACK for #1 packet.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))})
+
+ // Probe Timeout (PTO) should be two times RTT. TLP will trigger for #3
+ // packet. RACK adds an additional timeout of 200ms if the number of
+ // outstanding packets is equal to 1.
+ rtt, rto := getRTTAndRTO(t, dut, acceptFd)
+ pto := rtt*2 + (200 * time.Millisecond)
+ if rto < pto {
+ pto = rto
+ }
+ // We expect the 3rd packet (the last unacknowledged packet) to be
+ // retransmitted.
+ tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize))
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ diff := time.Now().Sub(lastSent)
+ if diff < pto {
+ t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto)
+ }
+ closeSACKConnection(t, dut, conn, acceptFd, listenFd)
+}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index 1ab9ee1b2..b15b8fc25 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -66,33 +66,39 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
- startProbeDuration := time.Second
- current := startProbeDuration
- first := time.Now()
// Ask the dut to send out data.
dut.Send(t, acceptFd, sampleData, 0)
+
+ var prev time.Duration
// Expect the dut to keep the connection alive as long as the remote is
// acknowledging the zero-window probes.
- for i := 0; i < 5; i++ {
+ for i := 1; i <= 5; i++ {
start := time.Now()
// Expect zero-window probe with a timeout which is a function of the typical
// first retransmission time. The retransmission times is supposed to
// exponentially increase.
- if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Duration(i)*time.Second); err != nil {
t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i)
}
- if i == 0 {
- startProbeDuration = time.Now().Sub(first)
- current = 2 * startProbeDuration
+ if i == 1 {
+ // Skip the first probe as computing transmit time for that is
+ // non-deterministic because of the arbitrary time taken for
+ // the dut to receive a send command and issue a send.
continue
}
- // Check if the probes came at exponentially increasing intervals.
- if got, want := time.Since(start), current-startProbeDuration; got < want {
+
+ // Check if the time taken to receive the probe from the dut is
+ // increasing exponentially. To avoid flakes, use a correction
+ // factor for the expected duration which accounts for any
+ // scheduling non-determinism.
+ const timeCorrection = 200 * time.Millisecond
+ got := time.Since(start)
+ if want := (2 * prev) - timeCorrection; prev != 0 && got < want {
t.Errorf("got zero probe %d after %s, want >= %s", i, got, want)
}
+ prev = got
// Acknowledge the zero-window probes from the dut.
conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
- current *= 2
}
// Advertize non-zero window.
conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
deleted file mode 100644
index b29c07825..000000000
--- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp_recv_mcast_bcast_test
-
-import (
- "context"
- "flag"
- "fmt"
- "net"
- "syscall"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/test/packetimpact/testbench"
-)
-
-func init() {
- testbench.Initialize(flag.CommandLine)
-}
-
-func TestUDPRecvMcastBcast(t *testing.T) {
- dut := testbench.NewDUT(t)
- subnetBcastAddr := broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32))
- for _, v := range []struct {
- bound, to net.IP
- }{
- {bound: net.IPv4zero, to: subnetBcastAddr},
- {bound: net.IPv4zero, to: net.IPv4bcast},
- {bound: net.IPv4zero, to: net.IPv4allsys},
-
- {bound: subnetBcastAddr, to: subnetBcastAddr},
-
- // FIXME(gvisor.dev/issue/4896): Previously by the time subnetBcastAddr is
- // created, IPv4PrefixLength is still 0 because genPseudoFlags is not called
- // yet, it was only called in NewDUT, so the test didn't do what the author
- // original intended to and becomes failing because we process all flags at
- // the very beginning.
- //
- // {bound: subnetBcastAddr, to: net.IPv4bcast},
-
- {bound: net.IPv4bcast, to: net.IPv4bcast},
- {bound: net.IPv4allsys, to: net.IPv4allsys},
- } {
- t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) {
- boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound)
- defer dut.Close(t, boundFD)
- conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close(t)
-
- payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
- conn.SendIP(
- t,
- testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))},
- testbench.UDP{},
- &testbench.Payload{Bytes: payload},
- )
- got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) {
- dut := testbench.NewDUT(t)
- boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv4)
- dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0})
- defer dut.Close(t, boundFD)
- conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close(t)
-
- for _, to := range []net.IP{
- broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)),
- net.IPv4(255, 255, 255, 255),
- net.IPv4(224, 0, 0, 1),
- } {
- t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) {
- payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
- conn.SendIP(
- t,
- testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))},
- testbench.UDP{},
- &testbench.Payload{Bytes: payload},
- )
- ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0)
- if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
- t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
- }
- })
- }
-}
-
-func broadcastAddr(ip net.IP, mask net.IPMask) net.IP {
- result := make(net.IP, net.IPv4len)
- ip4 := ip.To4()
- for i := range ip4 {
- result[i] = ip4[i] | ^mask[i]
- }
- return result
-}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index 7ee2c8014..6e45cb143 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -15,13 +15,18 @@
package udp_send_recv_dgram_test
import (
+ "context"
"flag"
+ "fmt"
"net"
+ "syscall"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/test/packetimpact/testbench"
)
@@ -30,74 +35,295 @@ func init() {
}
type udpConn interface {
- Send(*testing.T, testbench.UDP, ...testbench.Layer)
- ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error)
- Drain(*testing.T)
+ SrcPort(*testing.T) uint16
+ SendFrame(*testing.T, testbench.Layers, ...testbench.Layer)
+ ExpectFrame(*testing.T, testbench.Layers, time.Duration) (testbench.Layers, error)
Close(*testing.T)
}
+type testCase struct {
+ bindTo, sendTo net.IP
+ sendToBroadcast, bindToDevice, expectData bool
+}
+
func TestUDP(t *testing.T) {
dut := testbench.NewDUT(t)
+ subnetBcast := func() net.IP {
+ subnet := (&tcpip.AddressWithPrefix{
+ Address: tcpip.Address(dut.Net.RemoteIPv4.To4()),
+ PrefixLen: dut.Net.IPv4PrefixLength,
+ }).Subnet()
+ return net.IP(subnet.Broadcast())
+ }()
- for _, isIPv4 := range []bool{true, false} {
- ipVersionName := "IPv6"
- if isIPv4 {
- ipVersionName = "IPv4"
- }
- t.Run(ipVersionName, func(t *testing.T) {
- var addr net.IP
- if isIPv4 {
- addr = dut.Net.RemoteIPv4
- } else {
- addr = dut.Net.RemoteIPv6
+ t.Run("Send", func(t *testing.T) {
+ var testCases []testCase
+ // Test every valid combination of bound/unbound, broadcast/multicast/unicast
+ // bound/destination address, and bound/not-bound to device.
+ for _, bindTo := range []net.IP{
+ nil, // Do not bind.
+ net.IPv4zero,
+ net.IPv4bcast,
+ net.IPv4allsys,
+ subnetBcast,
+ dut.Net.RemoteIPv4,
+ dut.Net.RemoteIPv6,
+ } {
+ for _, sendTo := range []net.IP{
+ net.IPv4bcast,
+ net.IPv4allsys,
+ subnetBcast,
+ dut.Net.LocalIPv4,
+ dut.Net.LocalIPv6,
+ } {
+ // Cannot send to an IPv4 address from a socket bound to IPv6 (except for IPv4-mapped IPv6),
+ // and viceversa.
+ if bindTo != nil && ((bindTo.To4() == nil) != (sendTo.To4() == nil)) {
+ continue
+ }
+ for _, bindToDevice := range []bool{true, false} {
+ expectData := true
+ switch {
+ case bindTo.Equal(dut.Net.RemoteIPv4):
+ // If we're explicitly bound to an interface's unicast address,
+ // packets are always sent on that interface.
+ case bindToDevice:
+ // If we're explicitly bound to an interface, packets are always
+ // sent on that interface.
+ case !sendTo.Equal(net.IPv4bcast) && !sendTo.IsMulticast():
+ // If we're not sending to limited broadcast or multicast, the route table
+ // will be consulted and packets will be sent on the correct interface.
+ default:
+ expectData = false
+ }
+ testCases = append(
+ testCases,
+ testCase{
+ bindTo: bindTo,
+ sendTo: sendTo,
+ sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast),
+ bindToDevice: bindToDevice,
+ expectData: expectData,
+ },
+ )
+ }
}
- boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, addr)
- defer dut.Close(t, boundFD)
-
- var conn udpConn
- var localAddr unix.Sockaddr
- if isIPv4 {
- v4Conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- localAddr = v4Conn.LocalAddr(t)
- conn = &v4Conn
- } else {
- v6Conn := dut.Net.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- localAddr = v6Conn.LocalAddr(t, dut.Net.RemoteDevID)
- conn = &v6Conn
- }
- defer conn.Close(t)
-
- testCases := []struct {
- name string
- payload []byte
- }{
- {"emptypayload", nil},
- {"small payload", []byte("hello world")},
- {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)},
- // Even though UDP allows larger dgrams we don't test it here as
- // they need to be fragmented and written out as individual
- // frames.
+ }
+ for _, tc := range testCases {
+ boundTestCaseName := "unbound"
+ if tc.bindTo != nil {
+ boundTestCaseName = fmt.Sprintf("bindTo=%s", tc.bindTo)
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- t.Run("Send", func(t *testing.T) {
- conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload})
- got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload
- if diff := cmp.Diff(want, got); diff != "" {
- t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
+ t.Run(fmt.Sprintf("%s/sendTo=%s/bindToDevice=%t/expectData=%t", boundTestCaseName, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) {
+ runTestCase(
+ t,
+ dut,
+ tc,
+ func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) {
+ var destSockaddr unix.Sockaddr
+ if sendTo4 := tc.sendTo.To4(); sendTo4 != nil {
+ addr := unix.SockaddrInet4{
+ Port: int(conn.SrcPort(t)),
+ }
+ copy(addr.Addr[:], sendTo4)
+ destSockaddr = &addr
+ } else {
+ addr := unix.SockaddrInet6{
+ Port: int(conn.SrcPort(t)),
+ ZoneId: dut.Net.RemoteDevID,
+ }
+ copy(addr.Addr[:], tc.sendTo.To16())
+ destSockaddr = &addr
}
- })
- t.Run("Recv", func(t *testing.T) {
- conn.Drain(t)
- if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want {
- t.Fatalf("short write got: %d, want: %d", got, want)
+ if got, want := dut.SendTo(t, socketFD, payload, 0, destSockaddr), len(payload); int(got) != want {
+ t.Fatalf("got dut.SendTo = %d, want %d", got, want)
}
- if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil {
+ layers = append(layers, &testbench.Payload{
+ Bytes: payload,
+ })
+ _, err := conn.ExpectFrame(t, layers, time.Second)
+
+ if !tc.expectData && err == nil {
+ t.Fatal("received unexpected packet, socket is not bound to device")
+ }
+ if err != nil && tc.expectData {
t.Fatal(err)
}
- })
- })
+ },
+ )
+ })
+ }
+ })
+ t.Run("Recv", func(t *testing.T) {
+ // Test every valid combination of broadcast/multicast/unicast
+ // bound/destination address, and bound/not-bound to device.
+ var testCases []testCase
+ for _, addr := range []net.IP{
+ net.IPv4bcast,
+ net.IPv4allsys,
+ dut.Net.RemoteIPv4,
+ dut.Net.RemoteIPv6,
+ } {
+ for _, bindToDevice := range []bool{true, false} {
+ testCases = append(
+ testCases,
+ testCase{
+ bindTo: addr,
+ sendTo: addr,
+ sendToBroadcast: addr.Equal(subnetBcast) || addr.Equal(net.IPv4bcast),
+ bindToDevice: bindToDevice,
+ expectData: true,
+ },
+ )
}
- })
+ }
+ for _, bindTo := range []net.IP{
+ net.IPv4zero,
+ subnetBcast,
+ dut.Net.RemoteIPv4,
+ } {
+ for _, sendTo := range []net.IP{
+ subnetBcast,
+ net.IPv4bcast,
+ net.IPv4allsys,
+ } {
+ // TODO(gvisor.dev/issue/4896): Add bindTo=subnetBcast/sendTo=IPv4bcast
+ // and bindTo=subnetBcast/sendTo=IPv4allsys test cases.
+ if bindTo.Equal(subnetBcast) && (sendTo.Equal(net.IPv4bcast) || sendTo.IsMulticast()) {
+ continue
+ }
+ // Expect that a socket bound to a unicast address does not receive
+ // packets sent to an address other than the bound unicast address.
+ //
+ // Note: we cannot use net.IP.IsGlobalUnicast to test this condition
+ // because IsGlobalUnicast does not check whether the address is the
+ // subnet broadcast, and returns true in that case.
+ expectData := !bindTo.Equal(dut.Net.RemoteIPv4) || sendTo.Equal(dut.Net.RemoteIPv4)
+ for _, bindToDevice := range []bool{true, false} {
+ testCases = append(
+ testCases,
+ testCase{
+ bindTo: bindTo,
+ sendTo: sendTo,
+ sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast),
+ bindToDevice: bindToDevice,
+ expectData: expectData,
+ },
+ )
+ }
+ }
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("bindTo=%s/sendTo=%s/bindToDevice=%t/expectData=%t", tc.bindTo, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) {
+ runTestCase(
+ t,
+ dut,
+ tc,
+ func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) {
+ conn.SendFrame(t, layers, &testbench.Payload{Bytes: payload})
+
+ if tc.expectData {
+ got, want := dut.Recv(t, socketFD, int32(len(payload)+1), 0), payload
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
+ }
+ } else {
+ // Expected receive error, set a short receive timeout.
+ dut.SetSockOptTimeval(
+ t,
+ socketFD,
+ unix.SOL_SOCKET,
+ unix.SO_RCVTIMEO,
+ &unix.Timeval{
+ Sec: 1,
+ Usec: 0,
+ },
+ )
+ ret, recvPayload, errno := dut.RecvWithErrno(context.Background(), t, socketFD, 100, 0)
+ if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, recvPayload, errno)
+ }
+ }
+ },
+ )
+ })
+ }
+ })
+}
+
+func runTestCase(
+ t *testing.T,
+ dut testbench.DUT,
+ tc testCase,
+ runTc func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers),
+) {
+ var (
+ socketFD int32
+ outgoingUDP, incomingUDP testbench.UDP
+ )
+ if tc.bindTo != nil {
+ var remotePort uint16
+ socketFD, remotePort = dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, tc.bindTo)
+ outgoingUDP.DstPort = &remotePort
+ incomingUDP.SrcPort = &remotePort
+ } else {
+ // An unbound socket will auto-bind to INNADDR_ANY and a random
+ // port on sendto.
+ socketFD = dut.Socket(t, unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
+ }
+ defer dut.Close(t, socketFD)
+ if tc.bindToDevice {
+ dut.SetSockOpt(t, socketFD, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, []byte(dut.Net.RemoteDevName))
+ }
+
+ var ethernetLayer testbench.Ether
+ if tc.sendToBroadcast {
+ dut.SetSockOptInt(t, socketFD, unix.SOL_SOCKET, unix.SO_BROADCAST, 1)
+
+ // When sending to broadcast (subnet or limited), the expected ethernet
+ // address is also broadcast.
+ ethernetBroadcastAddress := header.EthernetBroadcastAddress
+ ethernetLayer.DstAddr = &ethernetBroadcastAddress
+ } else if tc.sendTo.IsMulticast() {
+ ethernetMulticastAddress := header.EthernetAddressFromMulticastIPv4Address(tcpip.Address(tc.sendTo.To4()))
+ ethernetLayer.DstAddr = &ethernetMulticastAddress
+ }
+ expectedLayers := testbench.Layers{&ethernetLayer}
+
+ var conn udpConn
+ if sendTo4 := tc.sendTo.To4(); sendTo4 != nil {
+ v4Conn := dut.Net.NewUDPIPv4(t, outgoingUDP, incomingUDP)
+ conn = &v4Conn
+ expectedLayers = append(
+ expectedLayers,
+ &testbench.IPv4{
+ DstAddr: testbench.Address(tcpip.Address(sendTo4)),
+ },
+ )
+ } else {
+ v6Conn := dut.Net.NewUDPIPv6(t, outgoingUDP, incomingUDP)
+ conn = &v6Conn
+ expectedLayers = append(
+ expectedLayers,
+ &testbench.IPv6{
+ DstAddr: testbench.Address(tcpip.Address(tc.sendTo)),
+ },
+ )
+ }
+ defer conn.Close(t)
+
+ expectedLayers = append(expectedLayers, &incomingUDP)
+ for _, v := range []struct {
+ name string
+ payload []byte
+ }{
+ {"emptypayload", nil},
+ {"small payload", []byte("hello world")},
+ {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger dgrams we don't test it here as
+ // they need to be fragmented and written out as individual
+ // frames.
+ } {
+ runTc(t, dut, conn, socketFD, tc, v.payload, expectedLayers)
}
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 0da35f7be..6ee2b73c1 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -569,7 +569,7 @@ syscall_test(
# syscall_test(vfs2="True",test = "//test/syscalls/linux:sigaltstack_test")
syscall_test(
- test = "//test/syscalls/linux:sigiret_test",
+ test = "//test/syscalls/linux:sigreturn_test",
)
syscall_test(
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index d184712e3..2b4b6f348 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -582,6 +582,7 @@ cc_binary(
"//test/util:eventfd_util",
"//test/util:file_descriptor",
gtest,
+ "//test/util:fs_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
@@ -797,6 +798,7 @@ cc_binary(
linkstatic = 1,
deps = [
":socket_test_util",
+ "//test/util:capability_util",
"//test/util:cleanup",
"//test/util:eventfd_util",
"//test/util:file_descriptor",
@@ -807,6 +809,7 @@ cc_binary(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
gtest,
+ "//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:save_util",
@@ -978,6 +981,7 @@ cc_binary(
"//test/util:epoll_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
@@ -2191,11 +2195,11 @@ cc_binary(
)
cc_binary(
- name = "sigiret_test",
+ name = "sigreturn_test",
testonly = 1,
srcs = select_arch(
- amd64 = ["sigiret.cc"],
- arm64 = [],
+ amd64 = ["sigreturn_amd64.cc"],
+ arm64 = ["sigreturn_arm64.cc"],
),
linkstatic = 1,
deps = [
diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc
index a06b5cfd6..8233df0f8 100644
--- a/test/syscalls/linux/chmod.cc
+++ b/test/syscalls/linux/chmod.cc
@@ -98,6 +98,42 @@ TEST(ChmodTest, FchmodatBadF) {
ASSERT_THAT(fchmodat(-1, "foo", 0444, 0), SyscallFailsWithErrno(EBADF));
}
+TEST(ChmodTest, FchmodFileWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChmodTest, FchmodDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH));
+
+ ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChmodTest, FchmodatWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ const auto parent_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(GetAbsoluteTestTmpdir().c_str(), O_PATH | O_DIRECTORY));
+
+ ASSERT_THAT(
+ fchmodat(parent_fd.get(), std::string(Basename(temp_file.path())).c_str(),
+ 0444, 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(open(temp_file.path().c_str(), O_RDWR),
+ SyscallFailsWithErrno(EACCES));
+}
+
TEST(ChmodTest, FchmodatNotDir) {
ASSERT_THAT(fchmodat(-1, "", 0444, 0), SyscallFailsWithErrno(ENOENT));
}
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 5530ad18f..ff0d39343 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -48,6 +48,36 @@ TEST(ChownTest, FchownatBadF) {
ASSERT_THAT(fchownat(-1, "fff", 0, 0, 0), SyscallFailsWithErrno(EBADF));
}
+TEST(ChownTest, FchownFileWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChownTest, FchownDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH));
+
+ ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()),
+ SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ChownTest, FchownatWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ const auto dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH));
+ ASSERT_THAT(
+ fchownat(dirfd.get(), file.path().c_str(), geteuid(), getegid(), 0),
+ SyscallSucceeds());
+}
+
TEST(ChownTest, FchownatEmptyPath) {
const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const auto fd =
@@ -209,6 +239,14 @@ INSTANTIATE_TEST_SUITE_P(
owner, group, 0);
MaybeSave();
return errorFromReturn("fchownat-dirfd", rc);
+ },
+ [](const std::string& path, uid_t owner, gid_t group) -> PosixError {
+ ASSIGN_OR_RETURN_ERRNO(auto dirfd, Open(std::string(Dirname(path)),
+ O_DIRECTORY | O_PATH));
+ int rc = fchownat(dirfd.get(), std::string(Basename(path)).c_str(),
+ owner, group, 0);
+ MaybeSave();
+ return errorFromReturn("fchownat-opathdirfd", rc);
}));
} // namespace
diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc
index 4f773bc75..ba4e13fb9 100644
--- a/test/syscalls/linux/dup.cc
+++ b/test/syscalls/linux/dup.cc
@@ -18,6 +18,7 @@
#include "gtest/gtest.h"
#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -44,14 +45,6 @@ PosixErrorOr<FileDescriptor> Dup3(const FileDescriptor& fd, int target_fd,
return FileDescriptor(new_fd);
}
-void CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) {
- struct stat stat_result1, stat_result2;
- ASSERT_THAT(fstat(fd1.get(), &stat_result1), SyscallSucceeds());
- ASSERT_THAT(fstat(fd2.get(), &stat_result2), SyscallSucceeds());
- EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev);
- EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino);
-}
-
TEST(DupTest, Dup) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
@@ -59,7 +52,7 @@ TEST(DupTest, Dup) {
// Dup the descriptor and make sure it's the same file.
FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
ASSERT_NE(fd.get(), nfd.get());
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
}
TEST(DupTest, DupClearsCloExec) {
@@ -70,10 +63,24 @@ TEST(DupTest, DupClearsCloExec) {
// Duplicate the descriptor. Ensure that it doesn't have FD_CLOEXEC set.
FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
ASSERT_NE(fd.get(), nfd.get());
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
}
+TEST(DupTest, DupWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH));
+ int flags;
+ ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+
+ // Dup the descriptor and make sure it's the same file.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_NE(fd.get(), nfd.get());
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags));
+}
+
TEST(DupTest, Dup2) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
@@ -82,13 +89,13 @@ TEST(DupTest, Dup2) {
FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
ASSERT_NE(fd.get(), nfd.get());
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
// Dup over the file above.
int target_fd = nfd.release();
FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd));
EXPECT_EQ(target_fd, nfd2.get());
- CheckSameFile(fd, nfd2);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2));
}
TEST(DupTest, Dup2SameFD) {
@@ -99,6 +106,28 @@ TEST(DupTest, Dup2SameFD) {
ASSERT_THAT(dup2(fd.get(), fd.get()), SyscallSucceedsWithValue(fd.get()));
}
+TEST(DupTest, Dup2WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH));
+ int flags;
+ ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+
+ // Regular dup once.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+
+ ASSERT_NE(fd.get(), nfd.get());
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags));
+
+ // Dup over the file above.
+ int target_fd = nfd.release();
+ FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd));
+ EXPECT_EQ(target_fd, nfd2.get());
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2));
+ EXPECT_THAT(fcntl(nfd2.get(), F_GETFL), SyscallSucceedsWithValue(flags));
+}
+
TEST(DupTest, Dup3) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
@@ -106,16 +135,16 @@ TEST(DupTest, Dup3) {
// Regular dup once.
FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
ASSERT_NE(fd.get(), nfd.get());
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
// Dup over the file above, check that it has no CLOEXEC.
nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0));
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
// Dup over the file again, check that it does not CLOEXEC.
nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC));
- CheckSameFile(fd, nfd);
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
}
@@ -127,6 +156,32 @@ TEST(DupTest, Dup3FailsSameFD) {
ASSERT_THAT(dup3(fd.get(), fd.get(), 0), SyscallFailsWithErrno(EINVAL));
}
+TEST(DupTest, Dup3WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+ int flags;
+ ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+
+ // Regular dup once.
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+ ASSERT_NE(fd.get(), nfd.get());
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+
+ // Dup over the file above, check that it has no CLOEXEC.
+ nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0));
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags));
+
+ // Dup over the file again, check that it does not CLOEXEC.
+ nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC));
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/fadvise64.cc b/test/syscalls/linux/fadvise64.cc
index 2af7aa6d9..ac24c4066 100644
--- a/test/syscalls/linux/fadvise64.cc
+++ b/test/syscalls/linux/fadvise64.cc
@@ -45,6 +45,17 @@ TEST(FAdvise64Test, Basic) {
SyscallSucceeds());
}
+TEST(FAdvise64Test, FAdvise64WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL),
+ SyscallFailsWithErrno(EBADF));
+ ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL),
+ SyscallFailsWithErrno(EBADF));
+}
+
TEST(FAdvise64Test, InvalidArgs) {
auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
index edd23e063..5c839447e 100644
--- a/test/syscalls/linux/fallocate.cc
+++ b/test/syscalls/linux/fallocate.cc
@@ -108,6 +108,13 @@ TEST_F(AllocateTest, FallocateReadonly) {
EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF));
}
+TEST_F(AllocateTest, FallocateWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+ EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF));
+}
+
TEST_F(AllocateTest, FallocatePipe) {
int pipes[2];
EXPECT_THAT(pipe(pipes), SyscallSucceeds());
diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc
index 08bcae1e8..c6675802d 100644
--- a/test/syscalls/linux/fchdir.cc
+++ b/test/syscalls/linux/fchdir.cc
@@ -71,6 +71,18 @@ TEST(FchdirTest, NotDir) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
+TEST(FchdirTest, FchdirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_dir.path(), O_PATH));
+ ASSERT_THAT(open(temp_dir.path().c_str(), O_DIRECTORY | O_PATH),
+ SyscallSucceeds());
+
+ EXPECT_THAT(fchdir(fd.get()), SyscallSucceeds());
+ // Change CWD to a permanent location as temp dirs will be cleaned up.
+ EXPECT_THAT(chdir("/"), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 4b581045b..4fa6751ff 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -15,6 +15,7 @@
#include <fcntl.h>
#include <signal.h>
#include <sys/epoll.h>
+#include <sys/mman.h>
#include <sys/types.h>
#include <syscall.h>
#include <unistd.h>
@@ -35,10 +36,12 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/save_util.h"
@@ -204,6 +207,41 @@ PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write,
return std::move(cleanup);
}
+TEST(FcntlTest, FcntlDupWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH));
+
+ int new_fd;
+ // Dup the descriptor and make sure it's the same file.
+ EXPECT_THAT(new_fd = fcntl(fd.get(), F_DUPFD, 0), SyscallSucceeds());
+
+ FileDescriptor nfd = FileDescriptor(new_fd);
+ ASSERT_NE(fd.get(), nfd.get());
+ ASSERT_NO_ERRNO(CheckSameFile(fd, nfd));
+ EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(O_PATH));
+}
+
+TEST(FcntlTest, SetFileStatusFlagWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+
+ EXPECT_THAT(fcntl(fd.get(), F_SETFL, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FcntlTest, BadFcntlsWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+
+ EXPECT_THAT(fcntl(fd.get(), F_SETOWN, 0), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(fcntl(fd.get(), F_GETOWN, 0), SyscallFailsWithErrno(EBADF));
+
+ EXPECT_THAT(fcntl(fd.get(), F_SETOWN_EX, 0), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(fcntl(fd.get(), F_GETOWN_EX, 0), SyscallFailsWithErrno(EBADF));
+}
+
TEST(FcntlTest, SetCloExecBadFD) {
// Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
@@ -223,6 +261,32 @@ TEST(FcntlTest, SetCloExec) {
ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
}
+TEST(FcntlTest, SetCloExecWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ // Open a file descriptor with FD_CLOEXEC descriptor flag not set.
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0));
+
+ // Set the FD_CLOEXEC flag.
+ ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST(FcntlTest, DupFDCloExecWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ // Open a file descriptor with FD_CLOEXEC descriptor flag not set.
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+ int nfd;
+ ASSERT_THAT(nfd = fcntl(fd.get(), F_DUPFD_CLOEXEC, 0), SyscallSucceeds());
+ FileDescriptor dup_fd(nfd);
+
+ // Check for the FD_CLOEXEC flag.
+ ASSERT_THAT(fcntl(dup_fd.get(), F_GETFD),
+ SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
TEST(FcntlTest, ClearCloExec) {
// Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set.
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC));
@@ -264,6 +328,22 @@ TEST(FcntlTest, GetAllFlags) {
EXPECT_EQ(rflags, expected);
}
+// When O_PATH is specified in flags, flag bits other than O_CLOEXEC,
+// O_DIRECTORY, and O_NOFOLLOW are ignored.
+TEST(FcntlTest, GetOpathFlag) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ int flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND | O_PATH |
+ O_NOFOLLOW | O_DIRECTORY;
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), flags));
+
+ int expected = O_PATH | O_NOFOLLOW | O_DIRECTORY;
+
+ int rflags;
+ EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds());
+ EXPECT_EQ(rflags, expected);
+}
+
TEST(FcntlTest, SetFlags) {
TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), 0));
@@ -392,6 +472,22 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsRead) {
EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallFailsWithErrno(EBADF));
}
+TEST_F(FcntlLockTest, SetLockWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ struct flock fl0;
+ fl0.l_type = F_WRLCK;
+ fl0.l_whence = SEEK_SET;
+ fl0.l_start = 0;
+ fl0.l_len = 0; // Lock all file
+
+ // Expect that setting a write lock using a Opath file descriptor
+ // won't work.
+ EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallFailsWithErrno(EBADF));
+}
+
TEST_F(FcntlLockTest, SetLockUnlockOnNothing) {
auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd =
@@ -1642,6 +1738,202 @@ TEST(FcntlTest, SetFlSetOwnSetSigDoNotRace) {
}
}
+TEST_F(FcntlLockTest, GetLockOnNothing) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds());
+ ASSERT_TRUE(fl.l_type == F_UNLCK);
+}
+
+TEST_F(FcntlLockTest, GetLockOnLockSameProcess) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds());
+ ASSERT_TRUE(fl.l_type == F_UNLCK);
+
+ fl.l_type = F_WRLCK;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds());
+ ASSERT_TRUE(fl.l_type == F_UNLCK);
+}
+
+TEST_F(FcntlLockTest, GetReadLockOnReadLock) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0);
+ TEST_CHECK(fl.l_type == F_UNLCK);
+ _exit(0);
+ }
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+TEST_F(FcntlLockTest, GetReadLockOnWriteLock) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ fl.l_type = F_RDLCK;
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0);
+ TEST_CHECK(fl.l_type == F_WRLCK);
+ TEST_CHECK(fl.l_pid == getppid());
+ _exit(0);
+ }
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+TEST_F(FcntlLockTest, GetWriteLockOnReadLock) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_RDLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ fl.l_type = F_WRLCK;
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0);
+ TEST_CHECK(fl.l_type == F_RDLCK);
+ TEST_CHECK(fl.l_pid == getppid());
+ _exit(0);
+ }
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+TEST_F(FcntlLockTest, GetWriteLockOnWriteLock) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ pid_t child_pid = fork();
+ if (child_pid == 0) {
+ TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0);
+ TEST_CHECK(fl.l_type == F_WRLCK);
+ TEST_CHECK(fl.l_pid == getppid());
+ _exit(0);
+ }
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Tests that the pid returned from F_GETLK is relative to the caller's PID
+// namespace.
+TEST_F(FcntlLockTest, GetLockRespectsPIDNamespace) {
+ SKIP_IF(IsRunningWithVFS1());
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ std::string filename = file.path();
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDWR, 0666));
+
+ // Lock in the parent process.
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+
+ auto child_getlk = [](void* filename) {
+ int fd = open((char*)filename, O_RDWR, 0666);
+ TEST_CHECK(fd >= 0);
+
+ struct flock fl;
+ fl.l_type = F_WRLCK;
+ fl.l_whence = SEEK_SET;
+ fl.l_start = 0;
+ fl.l_len = 40;
+ TEST_CHECK(fcntl(fd, F_GETLK, &fl) >= 0);
+ TEST_CHECK(fl.l_type == F_WRLCK);
+ // Parent PID should be 0 in the child PID namespace.
+ TEST_CHECK(fl.l_pid == 0);
+ close(fd);
+ return 0;
+ };
+
+ // Set up child process in a new PID namespace.
+ constexpr int kStackSize = 4096;
+ Mapping stack = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kStackSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0));
+ pid_t child_pid;
+ ASSERT_THAT(
+ child_pid = clone(child_getlk, (char*)stack.ptr() + stack.len(),
+ CLONE_NEWPID | SIGCHLD, (void*)filename.c_str()),
+ SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc
index 93c692dd6..2f2b14037 100644
--- a/test/syscalls/linux/getdents.cc
+++ b/test/syscalls/linux/getdents.cc
@@ -429,6 +429,32 @@ TYPED_TEST(GetdentsTest, NotDir) {
SyscallFailsWithErrno(ENOTDIR));
}
+// Test that getdents returns EBADF when called on an opath file.
+TYPED_TEST(GetdentsTest, OpathFile) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ typename TestFixture::DirentBufferType dirents(256);
+ EXPECT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(),
+ dirents.Size()),
+ SyscallFailsWithErrno(EBADF));
+}
+
+// Test that getdents returns EBADF when called on an opath directory.
+TYPED_TEST(GetdentsTest, OpathDirectory) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_PATH | O_DIRECTORY));
+
+ typename TestFixture::DirentBufferType dirents(256);
+ ASSERT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(),
+ dirents.Size()),
+ SyscallFailsWithErrno(EBADF));
+}
+
// Test that SEEK_SET to 0 causes getdents to re-read the entries.
TYPED_TEST(GetdentsTest, SeekResetsCursor) {
// . and .. should be in an otherwise empty directory.
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index 8137f0e29..a88c89e20 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -36,6 +36,7 @@
#include "test/util/epoll_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -315,8 +316,7 @@ PosixErrorOr<std::vector<Event>> DrainEvents(int fd) {
}
PosixErrorOr<FileDescriptor> InotifyInit1(int flags) {
- int fd;
- EXPECT_THAT(fd = inotify_init1(flags), SyscallSucceeds());
+ int fd = inotify_init1(flags);
if (fd < 0) {
return PosixError(errno, "inotify_init1() failed");
}
@@ -325,9 +325,7 @@ PosixErrorOr<FileDescriptor> InotifyInit1(int flags) {
PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path,
uint32_t mask) {
- int wd;
- EXPECT_THAT(wd = inotify_add_watch(fd, path.c_str(), mask),
- SyscallSucceeds());
+ int wd = inotify_add_watch(fd, path.c_str(), mask);
if (wd < 0) {
return PosixError(errno, "inotify_add_watch() failed");
}
@@ -784,6 +782,38 @@ TEST(Inotify, MoveWatchedTargetGeneratesEvents) {
EXPECT_EQ(events[0].cookie, events[1].cookie);
}
+// Tests that close events are only emitted when a file description drops its
+// last reference.
+TEST(Inotify, DupFD) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup());
+
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_OPEN, wd),
+ }));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ fd2.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_CLOSE_NOWRITE, wd),
+ }));
+}
+
TEST(Inotify, CoalesceEvents) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const FileDescriptor fd =
@@ -1779,11 +1809,9 @@ TEST(Inotify, Sendfile) {
EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)}));
}
-// On Linux, inotify behavior is not very consistent with splice(2). We try our
-// best to emulate Linux for very basic calls to splice.
TEST(Inotify, SpliceOnWatchTarget) {
- int pipes[2];
- ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+ int pipefds[2];
+ ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds());
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const FileDescriptor inotify_fd =
@@ -1798,15 +1826,20 @@ TEST(Inotify, SpliceOnWatchTarget) {
const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
- EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, 1, /*flags=*/0),
+ EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr, 1, /*flags=*/0),
SyscallSucceedsWithValue(1));
- // Surprisingly, events are not generated in Linux if we read from a file.
+ // Surprisingly, events may not be generated in Linux if we read from a file.
+ // fs/splice.c:generic_file_splice_read, which is used most often, does not
+ // generate events, whereas fs/splice.c:default_file_splice_read does.
std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
- ASSERT_THAT(events, Are({}));
+ if (IsRunningOnGvisor() && !IsRunningWithVFS1()) {
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, dir_wd, Basename(file.path())),
+ Event(IN_ACCESS, file_wd)}));
+ }
- EXPECT_THAT(splice(pipes[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0),
+ EXPECT_THAT(splice(pipefds[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0),
SyscallSucceedsWithValue(1));
events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
@@ -1817,8 +1850,8 @@ TEST(Inotify, SpliceOnWatchTarget) {
}
TEST(Inotify, SpliceOnInotifyFD) {
- int pipes[2];
- ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+ int pipefds[2];
+ ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds());
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const FileDescriptor fd =
@@ -1834,11 +1867,11 @@ TEST(Inotify, SpliceOnInotifyFD) {
char buf;
EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
- EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr,
+ EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr,
sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
SyscallSucceedsWithValue(sizeof(struct inotify_event)));
- const FileDescriptor read_fd(pipes[0]);
+ const FileDescriptor read_fd(pipefds[0]);
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
@@ -1936,24 +1969,29 @@ TEST(Inotify, Xattr) {
}
TEST(Inotify, Exec) {
- const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- const TempPath bin = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo(dir.path(), "/bin/true"));
-
+ SKIP_IF(IsRunningWithVFS1());
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
const int wd = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), bin.path(), IN_ALL_EVENTS));
+ InotifyAddWatch(fd.get(), "/bin/true", IN_ALL_EVENTS));
// Perform exec.
- ScopedThread t([&bin]() {
- ASSERT_THAT(execl(bin.path().c_str(), bin.path().c_str(), (char*)nullptr),
- SyscallSucceeds());
- });
- t.Join();
+ pid_t child = -1;
+ int execve_errno = -1;
+ auto kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec("/bin/true", {}, {}, nullptr, &child, &execve_errno));
+ ASSERT_EQ(0, execve_errno);
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(0, status);
+
+ // Process cleanup no longer needed.
+ kill.Release();
std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
- EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd)}));
+ EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd),
+ Event(IN_CLOSE_NOWRITE, wd)}));
}
// Watches without IN_EXCL_UNLINK, should continue to emit events for file
diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc
index b0a07a064..9b16d1558 100644
--- a/test/syscalls/linux/ioctl.cc
+++ b/test/syscalls/linux/ioctl.cc
@@ -76,6 +76,19 @@ TEST_F(IoctlTest, InvalidControlNumber) {
EXPECT_THAT(ioctl(STDOUT_FILENO, 0), SyscallFailsWithErrno(ENOTTY));
}
+TEST_F(IoctlTest, IoctlWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_PATH));
+
+ int set = 1;
+ EXPECT_THAT(ioctl(fd.get(), FIONBIO, &set), SyscallFailsWithErrno(EBADF));
+
+ EXPECT_THAT(ioctl(fd.get(), FIONCLEX), SyscallFailsWithErrno(EBADF));
+
+ EXPECT_THAT(ioctl(fd.get(), FIOCLEX), SyscallFailsWithErrno(EBADF));
+}
+
TEST_F(IoctlTest, FIONBIOSucceeds) {
EXPECT_FALSE(CheckNonBlocking(fd()));
int set = 1;
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
index 544681168..4f9ca1a65 100644
--- a/test/syscalls/linux/link.cc
+++ b/test/syscalls/linux/link.cc
@@ -50,6 +50,8 @@ bool IsSameFile(const std::string& f1, const std::string& f2) {
return stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino;
}
+// TODO(b/178640646): Add test for linkat with AT_EMPTY_PATH
+
TEST(LinkTest, CanCreateLinkFile) {
auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const std::string newname = NewTempAbsPath();
@@ -235,6 +237,59 @@ TEST(LinkTest, AbsPathsWithNonDirFDs) {
SyscallSucceeds());
}
+TEST(LinkTest, NewDirFDWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newname_parent = NewTempAbsPath();
+ const std::string newname_base = "child";
+ const std::string newname = JoinPath(newname_parent, newname_base);
+
+ // Create newname_parent directory, and get an FD.
+ EXPECT_THAT(mkdir(newname_parent.c_str(), 0777), SyscallSucceeds());
+ const FileDescriptor newname_parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(newname_parent, O_DIRECTORY | O_PATH));
+
+ // Link newname to oldfile, using newname_parent_fd.
+ EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), newname_parent_fd.get(),
+ newname.c_str(), 0),
+ SyscallSucceeds());
+
+ EXPECT_TRUE(IsSameFile(oldfile.path(), newname));
+}
+
+TEST(LinkTest, RelPathsNonDirFDsWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ // Create a file that will be passed as the directory fd for old/new names.
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+
+ // Using file_fd as olddirfd will fail.
+ EXPECT_THAT(linkat(file_fd.get(), "foo", AT_FDCWD, "bar", 0),
+ SyscallFailsWithErrno(ENOTDIR));
+
+ // Using file_fd as newdirfd will fail.
+ EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), file_fd.get(), "bar", 0),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+TEST(LinkTest, AbsPathsNonDirFDsWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string newname = NewTempAbsPath();
+
+ // Create a file that will be passed as the directory fd for old/new names.
+ TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH));
+
+ // Using file_fd as the dirfds is OK as long as paths are absolute.
+ EXPECT_THAT(linkat(file_fd.get(), oldfile.path().c_str(), file_fd.get(),
+ newname.c_str(), 0),
+ SyscallSucceeds());
+}
+
TEST(LinkTest, LinkDoesNotFollowSymlinks) {
// Create oldfile, and oldsymlink which points to it.
auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc
index 5a1973f60..6e714b12c 100644
--- a/test/syscalls/linux/madvise.cc
+++ b/test/syscalls/linux/madvise.cc
@@ -179,9 +179,9 @@ TEST(MadviseDontforkTest, DontforkShared) {
// First page is mapped in child and modifications are visible to parent
// via the shared mapping.
TEST_CHECK(IsMapped(ms1.addr()));
- ExpectAllMappingBytes(ms1, 2);
+ CheckAllMappingBytes(ms1, 2);
memset(ms1.ptr(), 1, kPageSize);
- ExpectAllMappingBytes(ms1, 1);
+ CheckAllMappingBytes(ms1, 1);
// Second page must not be mapped in child.
TEST_CHECK(!IsMapped(ms2.addr()));
@@ -222,9 +222,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) {
// page. The mapping is private so the modifications are not visible to
// the parent.
TEST_CHECK(IsMapped(mp1.addr()));
- ExpectAllMappingBytes(mp1, 1);
+ CheckAllMappingBytes(mp1, 1);
memset(mp1.ptr(), 11, kPageSize);
- ExpectAllMappingBytes(mp1, 11);
+ CheckAllMappingBytes(mp1, 11);
// Verify second page is not mapped.
TEST_CHECK(!IsMapped(mp2.addr()));
@@ -233,9 +233,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) {
// page. The mapping is private so the modifications are not visible to
// the parent.
TEST_CHECK(IsMapped(mp3.addr()));
- ExpectAllMappingBytes(mp3, 3);
+ CheckAllMappingBytes(mp3, 3);
memset(mp3.ptr(), 13, kPageSize);
- ExpectAllMappingBytes(mp3, 13);
+ CheckAllMappingBytes(mp3, 13);
};
EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index 83546830d..93a6d9cde 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -930,6 +930,18 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) {
SyscallFailsWithErrno(EACCES));
}
+// Mmap not allowed on O_PATH FDs.
+TEST_F(MMapFileTest, MmapFileWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ uintptr_t addr;
+ EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0),
+ SyscallFailsWithErrno(EBADF));
+}
+
// The FD must be readable.
TEST_P(MMapFileParamTest, WriteOnlyFd) {
const FileDescriptor fd =
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index fcd162ca2..733b17834 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -45,7 +45,7 @@ namespace {
// * O_CREAT
// * O_DIRECTORY
// * O_NOFOLLOW
-// * O_PATH <- Will we ever support this?
+// * O_PATH
//
// Special operations on open:
// * O_EXCL
@@ -517,6 +517,26 @@ TEST_F(OpenTest, OpenWithStrangeFlags) {
EXPECT_THAT(read(fd.get(), &c, 1), SyscallFailsWithErrno(EBADF));
}
+TEST_F(OpenTest, OpenWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ const DisableSave ds; // Permissions are dropped.
+ std::string path = NewTempAbsPath();
+
+ // Create a file without user permissions.
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055));
+
+ // Cannot open file as read only because we are owner and have no permissions
+ // set.
+ EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // Can open file with O_PATH because don't need permissions on the object when
+ // opening with O_PATH.
+ ASSERT_NO_ERRNO(Open(path, O_PATH));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 2ed4f6f9c..d25be0e30 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -548,13 +548,7 @@ TEST_P(RawPacketTest, SetSocketSendBuf) {
ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
SyscallSucceeds());
- // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack
- // matches linux behavior.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
-
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
index a9bfdb37b..999c8ab6b 100644
--- a/test/syscalls/linux/ping_socket.cc
+++ b/test/syscalls/linux/ping_socket.cc
@@ -31,51 +31,36 @@ namespace gvisor {
namespace testing {
namespace {
-class PingSocket : public ::testing::Test {
- protected:
- // Creates a socket to be used in tests.
- void SetUp() override;
-
- // Closes the socket created by SetUp().
- void TearDown() override;
-
- // The loopback address.
- struct sockaddr_in addr_;
-};
-
-void PingSocket::SetUp() {
- // On some hosts ping sockets are restricted to specific groups using the
- // sysctl "ping_group_range".
- int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
- if (s < 0 && errno == EPERM) {
- GTEST_SKIP();
- }
- close(s);
-
- addr_ = {};
- // Just a random port as the destination port number is irrelevant for ping
- // sockets.
- addr_.sin_port = 12345;
- addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- addr_.sin_family = AF_INET;
-}
-
-void PingSocket::TearDown() {}
-
// Test ICMP port exhaustion returns EAGAIN.
//
// We disable both random/cooperative S/R for this test as it makes way too many
// syscalls.
-TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+TEST(PingSocket, ICMPPortExhaustion_NoRandomSave) {
DisableSave ds;
+
+ {
+ auto s = Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP);
+ if (!s.ok()) {
+ ASSERT_EQ(s.error().errno_value(), EACCES);
+ GTEST_SKIP();
+ }
+ }
+
+ const struct sockaddr_in addr = {
+ .sin_family = AF_INET,
+ .sin_addr =
+ {
+ .s_addr = htonl(INADDR_LOOPBACK),
+ },
+ };
+
std::vector<FileDescriptor> sockets;
constexpr int kSockets = 65536;
- addr_.sin_port = 0;
for (int i = 0; i < kSockets; i++) {
auto s =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
- int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_),
- sizeof(addr_));
+ int ret = connect(s.get(), reinterpret_cast<const struct sockaddr*>(&addr),
+ sizeof(addr));
if (ret == 0) {
sockets.push_back(std::move(s));
continue;
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
index bcdbbb044..c74990ba1 100644
--- a/test/syscalls/linux/pread64.cc
+++ b/test/syscalls/linux/pread64.cc
@@ -77,6 +77,16 @@ TEST_F(Pread64Test, WriteOnlyNotReadable) {
EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF));
}
+TEST_F(Pread64Test, Pread64WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ char buf[1024];
+ EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF));
+}
+
TEST_F(Pread64Test, DirNotReadable) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY));
diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc
index 5b0743fe9..1c40f0915 100644
--- a/test/syscalls/linux/preadv.cc
+++ b/test/syscalls/linux/preadv.cc
@@ -89,6 +89,20 @@ TEST(PreadvTest, MMConcurrencyStress) {
// The test passes if it neither deadlocks nor crashes the OS.
}
+// This test calls preadv with an O_PATH fd.
+TEST(PreadvTest, PreadvWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ struct iovec iov;
+ iov.iov_base = nullptr;
+ iov.iov_len = 0;
+
+ EXPECT_THAT(preadv(fd.get(), &iov, 1, 0), SyscallFailsWithErrno(EBADF));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
index 4a9acd7ae..cb58719c4 100644
--- a/test/syscalls/linux/preadv2.cc
+++ b/test/syscalls/linux/preadv2.cc
@@ -226,6 +226,24 @@ TEST(Preadv2Test, TestUnreadableFile) {
SyscallFailsWithErrno(EBADF));
}
+// This test calls preadv2 with a file opened with O_PATH.
+TEST(Preadv2Test, Preadv2WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ auto iov = absl::make_unique<struct iovec[]>(1);
+ iov[0].iov_base = nullptr;
+ iov[0].iov_len = 0;
+
+ EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/0,
+ /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+}
+
// Calling preadv2 with a non-negative offset calls preadv. Calling preadv with
// an unseekable file is not allowed. A pipe is used for an unseekable file.
TEST(Preadv2Test, TestUnseekableFileInvalid) {
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index 0b174e2be..85ff258df 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1338,6 +1338,7 @@ TEST_F(JobControlTest, SetTTYDifferentSession) {
TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild);
TEST_PCHECK(gcwstatus == 0);
});
+ ASSERT_NO_ERRNO(res);
}
TEST_F(JobControlTest, ReleaseTTY) {
@@ -1515,7 +1516,8 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) {
// - creates a child process in a new process group
// - sets that child as the foreground process group
// - kills its child and sets itself as the foreground process group.
-TEST_F(JobControlTest, SetForegroundProcessGroup) {
+// TODO(gvisor.dev/issue/5357): Fix and enable.
+TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroup) {
auto res = RunInChild([=]() {
TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0));
@@ -1557,6 +1559,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) {
TEST_PCHECK(pgid = getpgid(0) == 0);
TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid));
});
+ ASSERT_NO_ERRNO(res);
}
TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) {
@@ -1576,8 +1579,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) {
ASSERT_NO_ERRNO(ret);
}
-TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
- auto ret = RunInChild([=]() {
+// TODO(gvisor.dev/issue/5357): Fix and enable.
+TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroupEmptyProcessGroup) {
+ auto res = RunInChild([=]() {
TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0));
// Create a new process, put it in a new process group, make that group the
@@ -1595,6 +1599,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 &&
errno == ESRCH);
});
+ ASSERT_NO_ERRNO(res);
}
TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc
index e69794910..1b2f25363 100644
--- a/test/syscalls/linux/pwrite64.cc
+++ b/test/syscalls/linux/pwrite64.cc
@@ -77,6 +77,17 @@ TEST_F(Pwrite64, Overflow) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
+TEST_F(Pwrite64, Pwrite64WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ std::vector<char> buf(1);
+ EXPECT_THAT(PwriteFd(fd.get(), buf.data(), 1, 0),
+ SyscallFailsWithErrno(EBADF));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
index 63b686c62..00aed61b4 100644
--- a/test/syscalls/linux/pwritev2.cc
+++ b/test/syscalls/linux/pwritev2.cc
@@ -283,6 +283,23 @@ TEST(Pwritev2Test, ReadOnlyFile) {
SyscallFailsWithErrno(EBADF));
}
+TEST(Pwritev2Test, Pwritev2WithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+}
+
// This test calls pwritev2 with an invalid flag.
TEST(Pwritev2Test, InvalidFlag) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 955bcee4b..32924466f 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -621,13 +621,7 @@ TEST_P(RawSocketTest, SetSocketSendBuf) {
ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
SyscallSucceeds());
- // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack
- // matches linux behavior.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
-
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc
index 2633ba31b..98d5e432d 100644
--- a/test/syscalls/linux/read.cc
+++ b/test/syscalls/linux/read.cc
@@ -112,6 +112,15 @@ TEST_F(ReadTest, ReadDirectoryFails) {
EXPECT_THAT(ReadFd(file.get(), buf.data(), 1), SyscallFailsWithErrno(EISDIR));
}
+TEST_F(ReadTest, ReadWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+ std::vector<char> buf(1);
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallFailsWithErrno(EBADF));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
index baaf9f757..86808d255 100644
--- a/test/syscalls/linux/readv.cc
+++ b/test/syscalls/linux/readv.cc
@@ -251,6 +251,20 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) {
SyscallFailsWithErrno(EFAULT));
}
+TEST_F(ReadvTest, ReadvWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ char buffer[1024];
+ struct iovec iov[1];
+ iov[0].iov_base = buffer;
+ iov[0].iov_len = 1024;
+
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH));
+
+ ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EBADF));
+}
+
// This test depends on the maximum extent of a single readv() syscall, so
// we can't tolerate interruption from saving.
TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc
index d6e8b3e59..baf794152 100644
--- a/test/syscalls/linux/shm.cc
+++ b/test/syscalls/linux/shm.cc
@@ -256,32 +256,26 @@ TEST(ShmTest, IpcInfo) {
}
TEST(ShmTest, ShmInfo) {
- struct shm_info info;
-
- // We generally can't know what other processes on a linux machine
- // does with shared memory segments, so we can't test specific
- // numbers on Linux. When running under gvisor, we're guaranteed to
- // be the only ones using shm, so we can easily verify machine-wide
- // numbers.
- if (IsRunningOnGvisor()) {
- ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &info));
- EXPECT_EQ(info.used_ids, 0);
- EXPECT_EQ(info.shm_tot, 0);
- EXPECT_EQ(info.shm_rss, 0);
- EXPECT_EQ(info.shm_swp, 0);
- }
+ // Take a snapshot of the system before the test runs.
+ struct shm_info snap;
+ ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &snap));
const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ struct shm_info info;
ASSERT_NO_ERRNO(Shmctl(1, SHM_INFO, &info));
+ // We generally can't know what other processes on a linux machine do with
+ // shared memory segments, so we can't test specific numbers on Linux. When
+ // running under gvisor, we're guaranteed to be the only ones using shm, so
+ // we can easily verify machine-wide numbers.
if (IsRunningOnGvisor()) {
ASSERT_NO_ERRNO(Shmctl(shm.id(), SHM_INFO, &info));
- EXPECT_EQ(info.used_ids, 1);
- EXPECT_EQ(info.shm_tot, kAllocSize / kPageSize);
- EXPECT_EQ(info.shm_rss, kAllocSize / kPageSize);
+ EXPECT_EQ(info.used_ids, snap.used_ids + 1);
+ EXPECT_EQ(info.shm_tot, snap.shm_tot + (kAllocSize / kPageSize));
+ EXPECT_EQ(info.shm_rss, snap.shm_rss + (kAllocSize / kPageSize));
EXPECT_EQ(info.shm_swp, 0); // Gvisor currently never swaps.
}
@@ -378,18 +372,18 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) {
SetupGvisorDeathTest();
const auto rest = [&] {
- ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE(
+ ShmSegment shm = TEST_CHECK_NO_ERRNO_AND_VALUE(
Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777));
- char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
+ char* addr = TEST_CHECK_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0));
// Mark the segment as destroyed so it's automatically cleaned up when we
// crash below. We can't rely on the standard cleanup since the destructor
// will not run after the SIGSEGV. Note that this doesn't destroy the
// segment immediately since we're still attached to it.
- ASSERT_NO_ERRNO(shm.Rmid());
+ TEST_CHECK_NO_ERRNO(shm.Rmid());
addr[0] = 'x';
- ASSERT_NO_ERRNO(Shmdt(addr));
+ TEST_CHECK_NO_ERRNO(Shmdt(addr));
// This access should cause a SIGSEGV.
addr[0] = 'x';
diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigreturn_amd64.cc
index 6227774a4..6227774a4 100644
--- a/test/syscalls/linux/sigiret.cc
+++ b/test/syscalls/linux/sigreturn_amd64.cc
diff --git a/test/syscalls/linux/sigreturn_arm64.cc b/test/syscalls/linux/sigreturn_arm64.cc
new file mode 100644
index 000000000..2c19e2984
--- /dev/null
+++ b/test/syscalls/linux/sigreturn_arm64.cc
@@ -0,0 +1,97 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/unistd.h>
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/ucontext.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/logging.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/timer_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr uint64_t kOrigX7 = 0xdeadbeeffacefeed;
+
+void sigvtalrm(int sig, siginfo_t* siginfo, void* _uc) {
+ ucontext_t* uc = reinterpret_cast<ucontext_t*>(_uc);
+
+ // Verify that:
+ // - x7 value in mcontext_t matches kOrigX7.
+ if (uc->uc_mcontext.regs[7] == kOrigX7) {
+ // Modify the value x7 in the ucontext. This is the value seen by the
+ // application after the signal handler returns.
+ uc->uc_mcontext.regs[7] = ~kOrigX7;
+ }
+}
+
+int testX7(uint64_t* val, uint64_t sysno, uint64_t tgid, uint64_t tid,
+ uint64_t signo) {
+ register uint64_t* x9 __asm__("x9") = val;
+ register uint64_t x8 __asm__("x8") = sysno;
+ register uint64_t x0 __asm__("x0") = tgid;
+ register uint64_t x1 __asm__("x1") = tid;
+ register uint64_t x2 __asm__("x2") = signo;
+
+ // Initialize x7, send SIGVTALRM to itself and read x7.
+ __asm__(
+ "ldr x7, [x9, 0]\n"
+ "svc 0\n"
+ "str x7, [x9, 0]\n"
+ : "=r"(x0)
+ : "r"(x0), "r"(x1), "r"(x2), "r"(x9), "r"(x8)
+ : "x7");
+ return x0;
+}
+
+// On ARM64, when ptrace stops on a system call, it uses the x7 register to
+// indicate whether the stop has been signalled from syscall entry or syscall
+// exit. This means that we can't get a value of this register and we can't
+// change it. More details are in the comment for tracehook_report_syscall in
+// arch/arm64/kernel/ptrace.c.
+//
+// CheckR7 checks that the ptrace platform handles the x7 register properly.
+TEST(SigreturnTest, CheckX7) {
+ // Setup signal handler for SIGVTALRM.
+ struct sigaction sa = {};
+ sigfillset(&sa.sa_mask);
+ sa.sa_sigaction = sigvtalrm;
+ sa.sa_flags = SA_SIGINFO;
+ auto const action_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGVTALRM, sa));
+
+ auto const mask_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGVTALRM));
+
+ uint64_t x7 = kOrigX7;
+
+ testX7(&x7, __NR_tgkill, getpid(), syscall(__NR_gettid), SIGVTALRM);
+
+ // The following check verifies that %x7 was not clobbered
+ // when returning from the signal handler (via sigreturn(2)).
+ EXPECT_EQ(x7, ~kOrigX7);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index 32f583581..b616c2c87 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -16,6 +16,7 @@
#include <sys/stat.h>
#include <sys/statfs.h>
#include <sys/types.h>
+#include <sys/wait.h>
#include <unistd.h>
#include "gtest/gtest.h"
@@ -90,8 +91,7 @@ TEST(SocketTest, UnixSocketStat) {
EXPECT_EQ(statbuf.st_mode, S_IFSOCK | (sock_perm & ~mask));
// Timestamps should be equal and non-zero.
- // TODO(b/158882152): Sockets currently don't implement timestamps.
- if (!IsRunningOnGvisor()) {
+ if (!IsRunningWithVFS1()) {
EXPECT_NE(statbuf.st_atime, 0);
EXPECT_EQ(statbuf.st_atime, statbuf.st_mtime);
EXPECT_EQ(statbuf.st_atime, statbuf.st_ctime);
@@ -111,6 +111,77 @@ TEST(SocketTest, UnixSocketStatFS) {
EXPECT_EQ(st.f_namelen, NAME_MAX);
}
+TEST(SocketTest, UnixSCMRightsOnlyPassedOnce_NoRandomSave) {
+ const DisableSave ds;
+
+ int sockets[2];
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds());
+ // Send more than what will fit inside the send/receive buffers, so that it is
+ // split into multiple messages.
+ constexpr int kBufSize = 0x100000;
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ TEST_PCHECK(close(sockets[0]) == 0);
+
+ // Construct a message with some control message.
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int))] = {};
+ std::vector<char> buf(kBufSize);
+ struct iovec iov = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ ((int*)CMSG_DATA(cmsg))[0] = sockets[1];
+
+ iov.iov_base = buf.data();
+ iov.iov_len = kBufSize;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ int n = sendmsg(sockets[1], &msg, 0);
+ TEST_PCHECK(n == kBufSize);
+ TEST_PCHECK(shutdown(sockets[1], SHUT_RDWR) == 0);
+ TEST_PCHECK(close(sockets[1]) == 0);
+ _exit(0);
+ }
+
+ close(sockets[1]);
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(int))] = {};
+ std::vector<char> buf(kBufSize);
+ struct iovec iov = {};
+ msg.msg_control = &control;
+ msg.msg_controllen = sizeof(control);
+
+ iov.iov_base = buf.data();
+ iov.iov_len = kBufSize;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // The control message should only be present in the first message received.
+ int n;
+ ASSERT_THAT(n = recvmsg(sockets[0], &msg, 0), SyscallSucceeds());
+ ASSERT_GT(n, 0);
+ ASSERT_EQ(msg.msg_controllen, CMSG_SPACE(sizeof(int)));
+
+ while (n > 0) {
+ ASSERT_THAT(n = recvmsg(sockets[0], &msg, 0), SyscallSucceeds());
+ ASSERT_EQ(msg.msg_controllen, 0);
+ }
+
+ close(sockets[0]);
+
+ int status;
+ ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
+ ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
using SocketOpenTest = ::testing::TestWithParam<int>;
// UDS cannot be opened.
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 831d96262..a73987a7e 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -65,6 +65,33 @@ TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) {
SyscallSucceeds());
}
+TEST_P(TCPSocketPairTest, CheckTcpInfoFields) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until second_fd sees the data and then recv it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0};
+ constexpr int kPollTimeoutMs = 2000; // Wait up to 2 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct tcp_info opt = {};
+ socklen_t optLen = sizeof(opt);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen),
+ SyscallSucceeds());
+
+ // Validates the received tcp_info fields.
+ EXPECT_EQ(opt.tcpi_ca_state, 0);
+ EXPECT_GT(opt.tcpi_snd_cwnd, 0);
+ EXPECT_GT(opt.tcpi_rto, 0);
+}
+
// This test validates that an RST is sent instead of a FIN when data is
// unread on calls to close(2).
TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadData) {
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index e557572a7..8eec31a46 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -2513,11 +2513,7 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) {
ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
SyscallSucceeds());
- // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
-
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc
index a16899493..22a4ee0d1 100644
--- a/test/syscalls/linux/socket_unix_cmsg.cc
+++ b/test/syscalls/linux/socket_unix_cmsg.cc
@@ -362,7 +362,7 @@ TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) {
// BasicFDPassUnalignedRecv starts off by sending a single FD just like
// BasicFDPass. The difference is that when calling recvmsg, the length of the
-// receive data is only aligned on a 4 byte boundry instead of the normal 8.
+// receive data is only aligned on a 4 byte boundary instead of the normal 8.
TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 6e7142a42..72f888659 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -221,6 +221,43 @@ TEST_F(StatTest, TrailingSlashNotCleanedReturnsENOTDIR) {
EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOTDIR));
}
+TEST_F(StatTest, FstatFileWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ struct stat st;
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH));
+
+ // Stat the directory.
+ ASSERT_THAT(fstat(fd.get(), &st), SyscallSucceeds());
+}
+
+TEST_F(StatTest, FstatDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ struct stat st;
+ TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY));
+
+ // Stat the directory.
+ ASSERT_THAT(fstat(dirfd.get(), &st), SyscallSucceeds());
+}
+
+// fstatat with an O_PATH fd
+TEST_F(StatTest, FstatatDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY));
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ struct stat st = {};
+ EXPECT_THAT(fstatat(dirfd.get(), tmpfile.path().c_str(), &st, 0),
+ SyscallSucceeds());
+ EXPECT_FALSE(S_ISDIR(st.st_mode));
+ EXPECT_TRUE(S_ISREG(st.st_mode));
+}
+
// Test fstatating a symlink directory.
TEST_F(StatTest, FstatatSymlinkDir) {
// Create a directory and symlink to it.
diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc
index f0fb166bd..d4ea8e026 100644
--- a/test/syscalls/linux/statfs.cc
+++ b/test/syscalls/linux/statfs.cc
@@ -64,6 +64,16 @@ TEST(FstatfsTest, InternalTmpfs) {
EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds());
}
+TEST(FstatfsTest, CanStatFileWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH));
+
+ struct statfs st;
+ EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds());
+}
+
TEST(FstatfsTest, InternalDevShm) {
auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const FileDescriptor fd =
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index 4d9eba7f0..ea219a091 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -269,6 +269,36 @@ TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) {
EXPECT_THAT(close(dirfd), SyscallSucceeds());
}
+TEST(SymlinkTest, SymlinkAtDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string filepath = NewTempAbsPathInDir(dir.path());
+ const std::string base = std::string(Basename(filepath));
+ FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH));
+
+ EXPECT_THAT(symlinkat("/dangling", dirfd.get(), base.c_str()),
+ SyscallSucceeds());
+}
+
+TEST(SymlinkTest, ReadlinkAtDirWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string filepath = NewTempAbsPathInDir(dir.path());
+ const std::string base = std::string(Basename(filepath));
+ ASSERT_THAT(symlink("/dangling", filepath.c_str()), SyscallSucceeds());
+
+ FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH));
+
+ std::vector<char> buf(1024);
+ int linksize;
+ EXPECT_THAT(
+ linksize = readlinkat(dirfd.get(), base.c_str(), buf.data(), 1024),
+ SyscallSucceeds());
+ EXPECT_EQ(0, strncmp("/dangling", buf.data(), linksize));
+}
+
TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc
index 8aa2525a9..84a2c4ed7 100644
--- a/test/syscalls/linux/sync.cc
+++ b/test/syscalls/linux/sync.cc
@@ -49,10 +49,20 @@ TEST(SyncTest, SyncFromPipe) {
EXPECT_THAT(close(pipes[1]), SyscallSucceeds());
}
-TEST(SyncTest, CannotSyncFileSytemAtBadFd) {
+TEST(SyncTest, CannotSyncFileSystemAtBadFd) {
EXPECT_THAT(syncfs(-1), SyscallFailsWithErrno(EBADF));
}
+TEST(SyncTest, CannotSyncFileSystemAtOpathFD) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH));
+
+ EXPECT_THAT(syncfs(fd.get()), SyscallFailsWithErrno(EBADF));
+}
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 714848b8e..9028ab024 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -2008,6 +2008,29 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
EXPECT_EQ(got, 0);
}
+// Tests that connecting to an unspecified address results in ECONNREFUSED.
+TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ memset(&addr, 0, addrlen);
+ addr.ss_family = GetParam();
+ auto do_connect = [&addr, addrlen]() {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(addr.ss_family, SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ };
+ do_connect();
+ // Test the v4 mapped address as well.
+ if (GetParam() == AF_INET6) {
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_addr.s6_addr[10] = sin6->sin6_addr.s6_addr[11] = 0xff;
+ do_connect();
+ }
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc
index bfc95ed38..17832c47d 100644
--- a/test/syscalls/linux/truncate.cc
+++ b/test/syscalls/linux/truncate.cc
@@ -196,6 +196,16 @@ TEST(TruncateTest, FtruncateNonWriteable) {
EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
}
+TEST(TruncateTest, FtruncateWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH));
+ EXPECT_THAT(ftruncate(fd.get(), 0), AnyOf(SyscallFailsWithErrno(EBADF),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
// ftruncate(2) should succeed as long as the file descriptor is writeable,
// regardless of whether the file permissions allow writing.
TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) {
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
index 77bcfbb8a..740992d0a 100644
--- a/test/syscalls/linux/write.cc
+++ b/test/syscalls/linux/write.cc
@@ -218,6 +218,44 @@ TEST_F(WriteTest, PwriteNoChangeOffset) {
EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total));
}
+TEST_F(WriteTest, WriteWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor f =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH));
+ int fd = f.get();
+
+ EXPECT_THAT(WriteBytes(fd, 1024), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(WriteTest, WritevWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor f =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH));
+ int fd = f.get();
+
+ char buf[16];
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+
+ EXPECT_THAT(writev(fd, &iov, /*__count=*/1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST_F(WriteTest, PwriteWithOpath) {
+ SKIP_IF(IsRunningWithVFS1());
+ TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor f =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH));
+ int fd = f.get();
+
+ const std::string data = "hello world\n";
+
+ EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0),
+ SyscallFailsWithErrno(EBADF));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index bd3f829c4..a953a55fe 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -607,6 +607,27 @@ TEST_F(XattrTest, XattrWithFD) {
EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
}
+TEST_F(XattrTest, XattrWithOPath) {
+ SKIP_IF(IsRunningWithVFS1());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), O_PATH));
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0),
+ SyscallFailsWithErrno(EBADF));
+
+ int buf;
+ EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size),
+ SyscallFailsWithErrno(EBADF));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)),
+ SyscallFailsWithErrno(EBADF));
+
+ EXPECT_THAT(fremovexattr(fd.get(), name), SyscallFailsWithErrno(EBADF));
+}
+
TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) {
// Trusted namespace not supported in VFS1.
SKIP_IF(IsRunningWithVFS1());
diff --git a/test/util/BUILD b/test/util/BUILD
index 1b028a477..e561f3daa 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -172,6 +172,7 @@ cc_library(
":posix_error",
":save_util",
":test_util",
+ gtest,
"@com_google_absl//absl/strings",
],
)
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index b16055dd8..5f1ce0d8a 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -663,5 +663,21 @@ PosixErrorOr<bool> IsOverlayfs(const std::string& path) {
return stat.f_type == OVERLAYFS_SUPER_MAGIC;
}
+PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) {
+ struct stat stat_result1, stat_result2;
+ int res = fstat(fd1.get(), &stat_result1);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("fstat ", fd1.get()));
+ }
+
+ res = fstat(fd2.get(), &stat_result2);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("fstat ", fd2.get()));
+ }
+ EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev);
+ EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino);
+
+ return NoError();
+}
} // namespace testing
} // namespace gvisor
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index c99cf5eb7..2190c3bca 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -191,6 +191,8 @@ PosixErrorOr<bool> IsTmpfs(const std::string& path);
// IsOverlayfs returns true if the file at path is backed by overlayfs.
PosixErrorOr<bool> IsOverlayfs(const std::string& path);
+PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2);
+
namespace internal {
// Not part of the public API.
std::string JoinPathImpl(std::initializer_list<absl::string_view> paths);
diff --git a/test/util/logging.cc b/test/util/logging.cc
index 5d5e76c46..5fadb076b 100644
--- a/test/util/logging.cc
+++ b/test/util/logging.cc
@@ -69,9 +69,7 @@ int WriteNumber(int fd, uint32_t val) {
} // namespace
void CheckFailure(const char* cond, size_t cond_size, const char* msg,
- size_t msg_size, bool include_errno) {
- int saved_errno = errno;
-
+ size_t msg_size, int errno_value) {
constexpr char kCheckFailure[] = "Check failed: ";
Write(2, kCheckFailure, sizeof(kCheckFailure) - 1);
Write(2, cond, cond_size);
@@ -81,10 +79,10 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg,
Write(2, msg, msg_size);
}
- if (include_errno) {
+ if (errno_value != 0) {
constexpr char kErrnoMessage[] = " (errno ";
Write(2, kErrnoMessage, sizeof(kErrnoMessage) - 1);
- WriteNumber(2, saved_errno);
+ WriteNumber(2, errno_value);
Write(2, ")", 1);
}
diff --git a/test/util/logging.h b/test/util/logging.h
index 589166fab..9d224ea05 100644
--- a/test/util/logging.h
+++ b/test/util/logging.h
@@ -21,7 +21,7 @@ namespace gvisor {
namespace testing {
void CheckFailure(const char* cond, size_t cond_size, const char* msg,
- size_t msg_size, bool include_errno);
+ size_t msg_size, int errno_value);
// If cond is false, aborts the current process.
//
@@ -30,7 +30,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg,
do { \
if (!(cond)) { \
::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \
- 0, false); \
+ 0, 0); \
} \
} while (0)
@@ -41,7 +41,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg,
do { \
if (!(cond)) { \
::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \
- sizeof(msg) - 1, false); \
+ sizeof(msg) - 1, 0); \
} \
} while (0)
@@ -52,7 +52,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg,
do { \
if (!(cond)) { \
::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \
- 0, true); \
+ 0, errno); \
} \
} while (0)
@@ -63,10 +63,39 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg,
do { \
if (!(cond)) { \
::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \
- sizeof(msg) - 1, true); \
+ sizeof(msg) - 1, errno); \
} \
} while (0)
+// expr must return PosixErrorOr<T>. The current process is aborted if
+// !PosixError<T>.ok().
+//
+// This macro is async-signal-safe.
+#define TEST_CHECK_NO_ERRNO(expr) \
+ ({ \
+ auto _expr_result = (expr); \
+ if (!_expr_result.ok()) { \
+ ::gvisor::testing::CheckFailure( \
+ #expr, sizeof(#expr) - 1, nullptr, 0, \
+ _expr_result.error().errno_value()); \
+ } \
+ })
+
+// expr must return PosixErrorOr<T>. The current process is aborted if
+// !PosixError<T>.ok(). Otherwise, PosixErrorOr<T> value is returned.
+//
+// This macro is async-signal-safe.
+#define TEST_CHECK_NO_ERRNO_AND_VALUE(expr) \
+ ({ \
+ auto _expr_result = (expr); \
+ if (!_expr_result.ok()) { \
+ ::gvisor::testing::CheckFailure( \
+ #expr, sizeof(#expr) - 1, nullptr, 0, \
+ _expr_result.error().errno_value()); \
+ } \
+ std::move(_expr_result).ValueOrDie(); \
+ })
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc
index 8b676751b..a6b0de24b 100644
--- a/test/util/multiprocess_util.cc
+++ b/test/util/multiprocess_util.cc
@@ -154,6 +154,9 @@ PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) {
pid_t pid = fork();
if (pid == 0) {
fn();
+ TEST_CHECK_MSG(!::testing::Test::HasFailure(),
+ "EXPECT*/ASSERT* failed. These are not async-signal-safe "
+ "and must not be called from fn.");
_exit(0);
}
MaybeSave();
diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h
index 2f3bf4a6f..840fde4ee 100644
--- a/test/util/multiprocess_util.h
+++ b/test/util/multiprocess_util.h
@@ -123,7 +123,8 @@ inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd,
// Calls fn in a forked subprocess and returns the exit status of the
// subprocess.
//
-// fn must be async-signal-safe.
+// fn must be async-signal-safe. Use of ASSERT/EXPECT functions is prohibited.
+// Use TEST_CHECK variants instead.
PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn);
} // namespace testing
diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc
index deed0c05b..8522e4c81 100644
--- a/test/util/posix_error.cc
+++ b/test/util/posix_error.cc
@@ -50,7 +50,7 @@ std::string PosixError::ToString() const {
ret = absl::StrCat("PosixError(errno=", errno_, " ", res, ")");
#endif
- if (!msg_.empty()) {
+ if (strnlen(msg_, sizeof(msg_)) > 0) {
ret.append(" ");
ret.append(msg_);
}
diff --git a/test/util/posix_error.h b/test/util/posix_error.h
index b634a7f78..27557ad44 100644
--- a/test/util/posix_error.h
+++ b/test/util/posix_error.h
@@ -26,12 +26,18 @@
namespace gvisor {
namespace testing {
+// PosixError must be async-signal-safe.
class ABSL_MUST_USE_RESULT PosixError {
public:
PosixError() {}
+
explicit PosixError(int errno_value) : errno_(errno_value) {}
- PosixError(int errno_value, std::string msg)
- : errno_(errno_value), msg_(std::move(msg)) {}
+
+ PosixError(int errno_value, std::string_view msg) : errno_(errno_value) {
+ // Check that `msg` will fit, leaving room for '\0' at the end.
+ TEST_CHECK(msg.size() < sizeof(msg_));
+ msg.copy(msg_, msg.size());
+ }
PosixError(PosixError&& other) = default;
PosixError& operator=(PosixError&& other) = default;
@@ -45,7 +51,7 @@ class ABSL_MUST_USE_RESULT PosixError {
const PosixError& error() const { return *this; }
int errno_value() const { return errno_; }
- std::string message() const { return msg_; }
+ const char* message() const { return msg_; }
// ToString produces a full string representation of this posix error
// including the printable representation of the errno and the error message.
@@ -58,7 +64,7 @@ class ABSL_MUST_USE_RESULT PosixError {
private:
int errno_ = 0;
- std::string msg_;
+ char msg_[1024] = {};
};
template <typename T>
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 2b20457e9..fb0fc6524 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -191,7 +191,7 @@ endif
build_paths = \
(set -euo pipefail; \
$(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) 2>&1 \
- | tee /proc/self/fd/2 \
+ | tee /dev/fd/2 \
| sed -n -e '/^Target/,$$p' \
| sed -n -e '/^ \($(subst /,\/,$(subst $(SPACE),\|,$(BUILD_ROOTS)))\)/p' \
| sed -e 's/ /\n/g' \
diff --git a/website/cmd/syscalldocs/main.go b/website/cmd/syscalldocs/main.go
index 327537214..830d2bac7 100644
--- a/website/cmd/syscalldocs/main.go
+++ b/website/cmd/syscalldocs/main.go
@@ -52,6 +52,7 @@ layout: docs
category: Compatibility
weight: 50
permalink: /docs/user_guide/compatibility/{{.OS}}/{{.Arch}}/
+include_in_menu: True
---
This table is a reference of {{.OS}} syscalls for the {{.Arch}} architecture and
@@ -75,9 +76,9 @@ syscalls. {{if .Undocumented}}{{.Undocumented}} syscalls are not yet documented.
</thead>
<tbody>
{{range $i, $syscall := .Syscalls}}
- <tr>
- <td><a class="doc-table-anchor" id="{{.Name}}"></a>{{.Number}}</td>
- <td><a href="http://man7.org/linux/man-pages/man2/{{.Name}}.2.html" target="_blank" rel="noopener">{{.Name}}</a></td>
+ <tr id="{{.Name}}">
+ <td><a href="#{{.Name}}">{{.Number}}</a></td>
+ <td><a href="{{.DocURL}}" target="_blank" rel="noopener">{{.Name}}</a></td>
<td>{{.Support}}</td>
<td>{{.Note}} {{range $i, $url := .URLs}}<br/>See: <a href="{{.}}">{{.}}</a>{{end}}</td>
</tr>
@@ -92,6 +93,27 @@ func Fatalf(format string, a ...interface{}) {
os.Exit(1)
}
+// syscallDocURL returns a doc url for a given syscall, doing its best to return a url that exists.
+func syscallDocURL(name string) string {
+ customDocs := map[string]string{
+ "io_pgetevents": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "rseq": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "io_uring_setup": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_setup.2.en.html",
+ "io_uring_enter": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_enter.2.en.html",
+ "io_uring_register": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_register.2.en.html",
+ "open_tree": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "move_mount": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "fsopen": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "fsconfig": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "fsmount": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ "fspick": "https://man7.org/linux/man-pages/man2/syscalls.2.html",
+ }
+ if url, ok := customDocs[name]; ok {
+ return url
+ }
+ return fmt.Sprintf("http://man7.org/linux/man-pages/man2/%s.2.html", name)
+}
+
func main() {
inputFlag := flag.String("in", "-", "File to input ('-' for stdin)")
outputDir := flag.String("out", ".", "Directory to output files.")
@@ -145,6 +167,7 @@ func main() {
Syscalls []struct {
Name string
Number uintptr
+ DocURL string
Support string
Note string
URLs []string
@@ -161,6 +184,7 @@ func main() {
Syscalls: []struct {
Name string
Number uintptr
+ DocURL string
Support string
Note string
URLs []string
@@ -187,14 +211,16 @@ func main() {
data.Syscalls = append(data.Syscalls, struct {
Name string
Number uintptr
+ DocURL string
Support string
Note string
URLs []string
}{
Name: s.Name,
Number: num,
+ DocURL: syscallDocURL(s.Name),
Support: s.Support,
- Note: s.Note, // TODO urls
+ Note: s.Note,
URLs: s.URLs,
})
}