summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/fcntl.go2
-rw-r--r--pkg/abi/linux/fuse.go8
-rw-r--r--pkg/abi/linux/sem.go35
-rw-r--r--pkg/abi/linux/socket.go14
-rw-r--r--pkg/coverage/coverage.go130
-rw-r--r--pkg/cpuid/cpuid.go11
-rw-r--r--pkg/cpuid/cpuid_x86.go11
-rw-r--r--pkg/crypto/BUILD12
-rw-r--r--pkg/crypto/crypto.go (renamed from pkg/sleep/empty.s)5
-rw-r--r--pkg/crypto/crypto_stdlib.go (renamed from pkg/sleep/commit_amd64.s)35
-rw-r--r--pkg/flipcall/BUILD3
-rw-r--r--pkg/flipcall/packet_window_mmap_amd64.go (renamed from pkg/flipcall/packet_window_mmap.go)0
-rw-r--r--pkg/flipcall/packet_window_mmap_arm64.go27
-rw-r--r--pkg/goid/BUILD1
-rw-r--r--pkg/log/json.go14
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/merkletree/merkletree.go10
-rw-r--r--pkg/p9/client.go4
-rw-r--r--pkg/p9/client_file.go25
-rw-r--r--pkg/p9/handlers.go81
-rw-r--r--pkg/p9/p9test/client_test.go15
-rw-r--r--pkg/p9/server.go14
-rw-r--r--pkg/p9/transport_test.go16
-rw-r--r--pkg/pool/pool.go1
-rw-r--r--pkg/refsvfs2/BUILD2
-rw-r--r--pkg/refsvfs2/refs_template.go9
-rw-r--r--pkg/safemem/block_unsafe.go20
-rw-r--r--pkg/seccomp/seccomp.go2
-rw-r--r--pkg/segment/test/set_functions.go1
-rw-r--r--pkg/sentry/arch/arch.go15
-rw-r--r--pkg/sentry/arch/arch_state_x86.go17
-rw-r--r--pkg/sentry/arch/signal.go39
-rw-r--r--pkg/sentry/control/pprof.go2
-rw-r--r--pkg/sentry/control/state.go1
-rw-r--r--pkg/sentry/fdimport/fdimport.go1
-rw-r--r--pkg/sentry/fs/BUILD2
-rw-r--r--pkg/sentry/fs/copy_up.go13
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/file.go47
-rw-r--r--pkg/sentry/fs/filetest/filetest.go4
-rw-r--r--pkg/sentry/fs/fs.go2
-rw-r--r--pkg/sentry/fs/gofer/BUILD2
-rw-r--r--pkg/sentry/fs/gofer/attr.go2
-rw-r--r--pkg/sentry/fs/gofer/file.go31
-rw-r--r--pkg/sentry/fs/gofer/inode.go2
-rw-r--r--pkg/sentry/fs/inode.go6
-rw-r--r--pkg/sentry/fs/proc/sys.go1
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go26
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_control.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go10
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/file.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go42
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go12
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go26
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go74
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go43
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go24
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go23
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go6
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go67
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go178
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go494
-rw-r--r--pkg/sentry/fsmetric/BUILD10
-rw-r--r--pkg/sentry/fsmetric/fsmetric.go83
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go7
-rw-r--r--pkg/sentry/kernel/fasync/BUILD2
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go96
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go11
-rw-r--r--pkg/sentry/kernel/kernel.go50
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go115
-rw-r--r--pkg/sentry/kernel/shm/BUILD3
-rw-r--r--pkg/sentry/kernel/signal.go4
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go4
-rw-r--r--pkg/sentry/kernel/syslog.go6
-rw-r--r--pkg/sentry/kernel/task_exit.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go16
-rw-r--r--pkg/sentry/memmap/memmap.go2
-rw-r--r--pkg/sentry/mm/aio_context.go79
-rw-r--r--pkg/sentry/mm/aio_context_state.go4
-rw-r--r--pkg/sentry/mm/lifecycle.go2
-rw-r--r--pkg/sentry/mm/mm_test.go43
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go62
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go14
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go40
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go1
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go7
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go2
-rw-r--r--pkg/sentry/platform/ring0/BUILD11
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go1
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s196
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go12
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go14
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s110
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go14
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go16
-rw-r--r--pkg/sentry/platform/ring0/pagetables/walker_arm64.go2
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go78
-rw-r--r--pkg/sentry/socket/hostinet/socket.go160
-rw-r--r--pkg/sentry/socket/netlink/socket.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack.go705
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go6
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go30
-rw-r--r--pkg/sentry/socket/socket.go249
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go27
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go78
-rw-r--r--pkg/sentry/socket/unix/unix.go15
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/socket.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go18
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go2
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/epoll.go10
-rw-r--r--pkg/sentry/vfs/file_description.go86
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go25
-rw-r--r--pkg/sentry/vfs/vfs.go3
-rw-r--r--pkg/sentry/watchdog/watchdog.go1
-rw-r--r--pkg/shim/v1/proc/process.go1
-rw-r--r--pkg/shim/v1/shim/BUILD1
-rw-r--r--pkg/shim/v1/shim/shim.go (renamed from pkg/sleep/commit_asm.go)13
-rw-r--r--pkg/shim/v1/utils/utils.go1
-rw-r--r--pkg/shim/v2/BUILD1
-rw-r--r--pkg/shim/v2/service.go86
-rw-r--r--pkg/sleep/BUILD4
-rw-r--r--pkg/sleep/commit_arm64.s38
-rw-r--r--pkg/sleep/commit_noasm.go33
-rw-r--r--pkg/sleep/sleep_unsafe.go27
-rw-r--r--pkg/state/tests/integer_test.go26
-rw-r--r--pkg/sync/BUILD35
-rw-r--r--pkg/sync/atomicptrmaptest/BUILD57
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap.go (renamed from pkg/syncevent/waiter_asm_unsafe.go)14
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap_test.go635
-rw-r--r--pkg/sync/checklocks_off_unsafe.go18
-rw-r--r--pkg/sync/checklocks_on_unsafe.go108
-rw-r--r--pkg/sync/generic_atomicptr_unsafe.go (renamed from pkg/sync/atomicptr_unsafe.go)4
-rw-r--r--pkg/sync/generic_atomicptrmap_unsafe.go503
-rw-r--r--pkg/sync/generic_seqatomic_unsafe.go (renamed from pkg/sync/seqatomic_unsafe.go)21
-rw-r--r--pkg/sync/goyield_go113_unsafe.go18
-rw-r--r--pkg/sync/goyield_unsafe.go (renamed from pkg/sync/spin_unsafe.go)8
-rw-r--r--pkg/sync/memmove_unsafe.go28
-rw-r--r--pkg/sync/mutex_test.go4
-rw-r--r--pkg/sync/mutex_unsafe.go51
-rw-r--r--pkg/sync/norace_unsafe.go11
-rw-r--r--pkg/sync/race_amd64.s (renamed from pkg/syncevent/waiter_amd64.s)17
-rw-r--r--pkg/sync/race_arm64.s (renamed from pkg/syncevent/waiter_arm64.s)17
-rw-r--r--pkg/sync/race_unsafe.go6
-rw-r--r--pkg/sync/runtime_unsafe.go129
-rw-r--r--pkg/sync/rwmutex_test.go2
-rw-r--r--pkg/sync/rwmutex_unsafe.go152
-rw-r--r--pkg/sync/seqcount.go34
-rw-r--r--pkg/sync/seqcount_test.go53
-rw-r--r--pkg/syncevent/BUILD4
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_unsafe.go59
-rw-r--r--pkg/syserr/host_linux.go2
-rw-r--r--pkg/syserr/netstack.go93
-rw-r--r--pkg/tcpip/checker/checker.go390
-rw-r--r--pkg/tcpip/header/BUILD5
-rw-r--r--pkg/tcpip/header/icmpv6.go13
-rw-r--r--pkg/tcpip/header/igmp.go181
-rw-r--r--pkg/tcpip/header/igmp_test.go110
-rw-r--r--pkg/tcpip/header/ipv4.go209
-rw-r--r--pkg/tcpip/header/ipv4_test.go179
-rw-r--r--pkg/tcpip/header/ipv6.go15
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go345
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go356
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go42
-rw-r--r--pkg/tcpip/header/mld.go103
-rw-r--r--pkg/tcpip/header/mld_test.go61
-rw-r--r--pkg/tcpip/header/ndp_options.go2
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go18
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go18
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go11
-rw-r--r--pkg/tcpip/link/loopback/loopback.go17
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go8
-rw-r--r--pkg/tcpip/link/nested/BUILD1
-rw-r--r--pkg/tcpip/link/nested/nested.go6
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go7
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go17
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go29
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go64
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go6
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go12
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD2
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go148
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go225
-rw-r--r--pkg/tcpip/network/ip/BUILD26
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go676
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go877
-rw-r--r--pkg/tcpip/network/ip_test.go129
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go345
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go215
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go255
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go277
-rw-r--r--pkg/tcpip/network/ipv6/BUILD18
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go57
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go194
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go285
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go310
-rw-r--r--pkg/tcpip/network/ipv6/mld.go262
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go297
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go351
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go140
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1261
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
-rw-r--r--pkg/tcpip/socketops.go341
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go179
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/forwarding_test.go16
-rw-r--r--pkg/tcpip/stack/ndp_test.go259
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go16
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go7
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go27
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go144
-rw-r--r--pkg/tcpip/stack/nic.go103
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/route.go200
-rw-r--r--pkg/tcpip/stack/stack.go147
-rw-r--r--pkg/tcpip/stack/stack_test.go217
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go14
-rw-r--r--pkg/tcpip/stack/transport_test.go34
-rw-r--r--pkg/tcpip/tcpip.go214
-rw-r--r--pkg/tcpip/tcpip_test.go44
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go193
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go25
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go82
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go45
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go94
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD6
-rw-r--r--pkg/tcpip/transport/tcp/accept.go6
-rw-r--r--pkg/tcpip/transport/tcp/connect.go42
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go288
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/reno_recovery.go67
-rw-r--r--pkg/tcpip/transport/tcp/sack_recovery.go120
-rw-r--r--pkg/tcpip/transport/tcp/segment.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go321
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go162
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go14
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go375
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go9
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go326
-rw-r--r--pkg/test/criutil/criutil.go12
-rw-r--r--pkg/test/dockerutil/container.go6
-rw-r--r--pkg/test/dockerutil/exec.go5
-rw-r--r--pkg/test/testutil/testutil.go58
-rw-r--r--pkg/usermem/usermem.go2
-rw-r--r--pkg/waiter/waiter.go11
-rw-r--r--pkg/waiter/waiter_test.go20
310 files changed, 14641 insertions, 5687 deletions
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index cc3571fad..d1ca56370 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -25,6 +25,8 @@ const (
F_SETLKW = 7
F_SETOWN = 8
F_GETOWN = 9
+ F_SETSIG = 10
+ F_GETSIG = 11
F_SETOWN_EX = 15
F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
index d91c97a64..1070b457c 100644
--- a/pkg/abi/linux/fuse.go
+++ b/pkg/abi/linux/fuse.go
@@ -19,16 +19,22 @@ import (
"gvisor.dev/gvisor/pkg/marshal/primitive"
)
+// FUSEOpcode is a FUSE operation code.
+//
// +marshal
type FUSEOpcode uint32
+// FUSEOpID is a FUSE operation ID.
+//
// +marshal
type FUSEOpID uint64
// FUSE_ROOT_ID is the id of root inode.
const FUSE_ROOT_ID = 1
-// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+// Opcodes for FUSE operations.
+//
+// Analogous to the opcodes in include/linux/fuse.h.
const (
FUSE_LOOKUP FUSEOpcode = 1
FUSE_FORGET = 2 /* no reply */
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 1b2f76c0b..2424884c1 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -32,6 +32,23 @@ const (
SEM_STAT_ANY = 20
)
+// Information about system-wide sempahore limits and parameters.
+//
+// Source: include/uapi/linux/sem.h
+const (
+ SEMMNI = 32000
+ SEMMSL = 32000
+ SEMMNS = SEMMNI * SEMMSL
+ SEMOPM = 500
+ SEMVMX = 32767
+ SEMAEM = SEMVMX
+
+ SEMUME = SEMOPM
+ SEMMNU = SEMMNS
+ SEMMAP = SEMMNS
+ SEMUSZ = 20
+)
+
const SEM_UNDO = 0x1000
// Sembuf is equivalent to struct sembuf.
@@ -42,3 +59,21 @@ type Sembuf struct {
SemOp int16
SemFlg int16
}
+
+// SemInfo is equivalent to struct seminfo.
+//
+// Source: include/uapi/linux/sem.h
+//
+// +marshal
+type SemInfo struct {
+ SemMap uint32
+ SemMni uint32
+ SemMns uint32
+ SemMnu uint32
+ SemMsl uint32
+ SemOpm uint32
+ SemUme uint32
+ SemUsz uint32
+ SemVmx uint32
+ SemAem uint32
+}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d156d41e4..556892dc3 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -111,12 +111,12 @@ type SockType int
// Socket types, from linux/net.h.
const (
SOCK_STREAM SockType = 1
- SOCK_DGRAM = 2
- SOCK_RAW = 3
- SOCK_RDM = 4
- SOCK_SEQPACKET = 5
- SOCK_DCCP = 6
- SOCK_PACKET = 10
+ SOCK_DGRAM SockType = 2
+ SOCK_RAW SockType = 3
+ SOCK_RDM SockType = 4
+ SOCK_SEQPACKET SockType = 5
+ SOCK_DCCP SockType = 6
+ SOCK_PACKET SockType = 10
)
// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
@@ -448,6 +448,8 @@ type ControlMessageCredentials struct {
// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
//
// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+//
+// +stateify savable
type ControlMessageIPPacketInfo struct {
NIC int32
LocalAddr InetAddr
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index a4f4b2c5e..fdfe31417 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -27,6 +27,7 @@ import (
"io"
"sort"
"sync/atomic"
+ "testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -34,12 +35,6 @@ import (
"github.com/bazelbuild/rules_go/go/tools/coverdata"
)
-// KcovAvailable returns whether the kcov coverage interface is available. It is
-// available as long as coverage is enabled for some files.
-func KcovAvailable() bool {
- return len(coverdata.Cover.Blocks) > 0
-}
-
// coverageMu must be held while accessing coverdata.Cover. This prevents
// concurrent reads/writes from multiple threads collecting coverage data.
var coverageMu sync.RWMutex
@@ -47,6 +42,22 @@ var coverageMu sync.RWMutex
// once ensures that globalData is only initialized once.
var once sync.Once
+// blockBitLength is the number of bits used to represent coverage block index
+// in a synthetic PC (the rest are used to represent the file index). Even
+// though a PC has 64 bits, we only use the lower 32 bits because some users
+// (e.g., syzkaller) may truncate that address to a 32-bit value.
+//
+// As of this writing, there are ~1200 files that can be instrumented and at
+// most ~1200 blocks per file, so 16 bits is more than enough to represent every
+// file and every block.
+const blockBitLength = 16
+
+// KcovAvailable returns whether the kcov coverage interface is available. It is
+// available as long as coverage is enabled for some files.
+func KcovAvailable() bool {
+ return len(coverdata.Cover.Blocks) > 0
+}
+
var globalData struct {
// files is the set of covered files sorted by filename. It is calculated at
// startup.
@@ -104,14 +115,14 @@ var coveragePool = sync.Pool{
// coverage tools, we reset the global coverage data every time this function is
// run.
func ConsumeCoverageData(w io.Writer) int {
- once.Do(initCoverageData)
+ InitCoverageData()
coverageMu.Lock()
defer coverageMu.Unlock()
total := 0
var pcBuffer [8]byte
- for fileIndex, file := range globalData.files {
+ for fileNum, file := range globalData.files {
counters := coverdata.Cover.Counters[file]
for index := 0; index < len(counters); index++ {
if atomic.LoadUint32(&counters[index]) == 0 {
@@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int {
}
// Non-zero coverage data found; consume it and report as a PC.
atomic.StoreUint32(&counters[index], 0)
- pc := globalData.syntheticPCs[fileIndex][index]
+ pc := globalData.syntheticPCs[fileNum][index]
usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
n, err := w.Write(pcBuffer[:])
if err != nil {
@@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int {
return total
}
-// initCoverageData initializes globalData. It should only be called once,
-// before any kcov data is written.
-func initCoverageData() {
- // First, order all files. Then calculate synthetic PCs for every block
- // (using the well-defined ordering for files as well).
- for file := range coverdata.Cover.Blocks {
- globalData.files = append(globalData.files, file)
+// InitCoverageData initializes globalData. It should be called before any kcov
+// data is written.
+func InitCoverageData() {
+ once.Do(func() {
+ // First, order all files. Then calculate synthetic PCs for every block
+ // (using the well-defined ordering for files as well).
+ for file := range coverdata.Cover.Blocks {
+ globalData.files = append(globalData.files, file)
+ }
+ sort.Strings(globalData.files)
+
+ for fileNum, file := range globalData.files {
+ blocks := coverdata.Cover.Blocks[file]
+ pcs := make([]uint64, 0, len(blocks))
+ for blockNum := range blocks {
+ pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum))
+ }
+ globalData.syntheticPCs = append(globalData.syntheticPCs, pcs)
+ }
+ })
+}
+
+// Symbolize prints information about the block corresponding to pc.
+func Symbolize(out io.Writer, pc uint64) error {
+ fileNum, blockNum := syntheticPCToIndexes(pc)
+ file, err := fileFromIndex(fileNum)
+ if err != nil {
+ return err
+ }
+ block, err := blockFromIndex(file, blockNum)
+ if err != nil {
+ return err
}
- sort.Strings(globalData.files)
-
- // nextSyntheticPC is the first PC that we generate for a block.
- //
- // This uses a standard-looking kernel range for simplicity.
- //
- // FIXME(b/160639712): This is only necessary because syzkaller requires
- // addresses in the kernel range. If we can remove this constraint, then we
- // should be able to use the actual addresses.
- var nextSyntheticPC uint64 = 0xffffffff80000000
- for _, file := range globalData.files {
- blocks := coverdata.Cover.Blocks[file]
- thisFile := make([]uint64, 0, len(blocks))
- for range blocks {
- thisFile = append(thisFile, nextSyntheticPC)
- nextSyntheticPC++ // Advance.
+ writeBlock(out, pc, file, block)
+ return nil
+}
+
+// WriteAllBlocks prints all information about all blocks along with their
+// corresponding synthetic PCs.
+func WriteAllBlocks(out io.Writer) {
+ for fileNum, file := range globalData.files {
+ for blockNum, block := range coverdata.Cover.Blocks[file] {
+ writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block)
}
- globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile)
}
}
+
+func calculateSyntheticPC(fileNum int, blockNum int) uint64 {
+ return (uint64(fileNum) << blockBitLength) + uint64(blockNum)
+}
+
+func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) {
+ return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1))
+}
+
+// fileFromIndex returns the name of the file in the sorted list of instrumented files.
+func fileFromIndex(i int) (string, error) {
+ total := len(globalData.files)
+ if i < 0 || i >= total {
+ return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total)
+ }
+ return globalData.files[i], nil
+}
+
+// blockFromIndex returns the i-th block in the given file.
+func blockFromIndex(file string, i int) (testing.CoverBlock, error) {
+ blocks, ok := coverdata.Cover.Blocks[file]
+ if !ok {
+ return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file)
+ }
+ total := len(blocks)
+ if i < 0 || i >= total {
+ return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total)
+ }
+ return blocks[i], nil
+}
+
+func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) {
+ io.WriteString(out, fmt.Sprintf("%#x\n", pc))
+ io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
+}
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index f7f9dbf86..69eeb7528 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -36,3 +36,14 @@ package cpuid
// On arm64, features are numbered according to the ELF HWCAP definition.
// arch/arm64/include/uapi/asm/hwcap.h
type Feature int
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 17a89c00d..392711e8f 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool {
return fs.VendorID == intelVendorID
}
-// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
-// subset of the host feature set.
-type ErrIncompatible struct {
- message string
-}
-
-// Error implements error.
-func (e ErrIncompatible) Error() string {
- return e.message
-}
-
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
diff --git a/pkg/crypto/BUILD b/pkg/crypto/BUILD
new file mode 100644
index 000000000..08fa772ca
--- /dev/null
+++ b/pkg/crypto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "crypto",
+ srcs = [
+ "crypto.go",
+ "crypto_stdlib.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/sleep/empty.s b/pkg/crypto/crypto.go
index fb37360ac..b26b55d37 100644
--- a/pkg/sleep/empty.s
+++ b/pkg/crypto/crypto.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Empty assembly file so empty func definitions work.
+// Package crypto wraps crypto primitives.
+package crypto
diff --git a/pkg/sleep/commit_amd64.s b/pkg/crypto/crypto_stdlib.go
index bc4ac2c3c..74a55a123 100644
--- a/pkg/sleep/commit_amd64.s
+++ b/pkg/crypto/crypto_stdlib.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,24 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "textflag.h"
+package crypto
-#define preparingG 1
+import (
+ "crypto/ecdsa"
+ "crypto/sha512"
+ "math/big"
+)
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVQ waitingG+8(FP), CX
- MOVQ g+0(FP), DX
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
- MOVQ $preparingG, AX
- LOCK
- CMPXCHGQ DX, 0(CX)
-
- SETEQ AX
- MOVB AX, ret+16(FP)
+// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
+// public key, pub. Its return value records whether the signature is valid.
+func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
+ return ecdsa.Verify(pub, hash, r, s)
+}
- RET
+// SumSha384 returns the SHA384 checksum of the data.
+func SumSha384(data []byte) (sum384 [sha512.Size384]byte) {
+ return sha512.Sum384(data)
+}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index aa8e4e1f3..cc31d0175 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,7 +11,8 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
- "packet_window_mmap.go",
+ "packet_window_mmap_amd64.go",
+ "packet_window_mmap_arm64.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go
index 869183b11..869183b11 100644
--- a/pkg/flipcall/packet_window_mmap.go
+++ b/pkg/flipcall/packet_window_mmap_amd64.go
diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/flipcall/packet_window_mmap_arm64.go
new file mode 100644
index 000000000..b9c9c44f6
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap_arm64.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
index d855b702c..08832a8ae 100644
--- a/pkg/goid/BUILD
+++ b/pkg/goid/BUILD
@@ -9,6 +9,7 @@ go_library(
"goid_amd64.s",
"goid_arm64.s",
],
+ stateify = False,
visibility = ["//visibility:public"],
)
diff --git a/pkg/log/json.go b/pkg/log/json.go
index bdf9d691e..8c52dcc87 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -27,8 +27,8 @@ type jsonLog struct {
}
// MarshalJSON implements json.Marshaler.MarashalJSON.
-func (lv Level) MarshalJSON() ([]byte, error) {
- switch lv {
+func (l Level) MarshalJSON() ([]byte, error) {
+ switch l {
case Warning:
return []byte(`"warning"`), nil
case Info:
@@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) {
case Debug:
return []byte(`"debug"`), nil
default:
- return nil, fmt.Errorf("unknown level %v", lv)
+ return nil, fmt.Errorf("unknown level %v", l)
}
}
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal
// from both string names and integers.
-func (lv *Level) UnmarshalJSON(b []byte) error {
+func (l *Level) UnmarshalJSON(b []byte) error {
switch s := string(b); s {
case "0", `"warning"`:
- *lv = Warning
+ *l = Warning
case "1", `"info"`:
- *lv = Info
+ *l = Info
case "2", `"debug"`:
- *lv = Debug
+ *l = Debug
default:
return fmt.Errorf("unknown level %q", s)
}
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 37e0605ad..2e3408357 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error {
case Warning:
f = Warningf
default:
- return fmt.Errorf("Unknown log level %v", l)
+ return fmt.Errorf("unknown log level %v", l)
}
stdlog.SetOutput(linewriter.NewWriter(func(p []byte) {
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 6acee90ef..aea7dde38 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -350,9 +350,13 @@ type VerifyParams struct {
// For verifyMetadata, params.data is not needed. It only accesses params.tree
// for the raw root hash.
func verifyMetadata(params *VerifyParams, layout *Layout) error {
- root := make([]byte, layout.digestSize)
- if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
- return fmt.Errorf("failed to read root hash: %w", err)
+ var root []byte
+ // Only read the root hash if we expect that the Merkle tree file is non-empty.
+ if params.Size != 0 {
+ root = make([]byte, layout.digestSize)
+ if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
}
descriptor := VerityDescriptor{
Name: params.Name,
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 71e944c30..eadea390a 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -570,6 +570,8 @@ func (c *Client) Version() uint32 {
func (c *Client) Close() {
// unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
// been called (by c.watch()).
- c.socket.Shutdown()
+ if err := c.socket.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err)
+ }
c.closedWg.Wait()
}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 28fe081d6..8b46a2987 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -478,28 +478,23 @@ func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
}
// Write implements part of the io.ReadWriter interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.Writer, but is more consistent with gVisor's pattern of
+// returning errors that correspond to Linux errnos. Since short writes without
+// error are common in Linux, returning a nil error is appropriate.
func (r *ReadWriterFile) Write(p []byte) (int, error) {
n, err := r.File.WriteAt(p, r.Offset)
r.Offset += uint64(n)
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return n, err
}
// WriteAt implements the io.WriteAt interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.WriterAt. See comment on Write for justification.
func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
- n, err := r.File.WriteAt(p, uint64(offset))
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return r.File.WriteAt(p, uint64(offset))
}
// Rename implements File.Rename.
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index abd237f46..81ceb37c5 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -296,25 +296,6 @@ func (t *Tlopen) handle(cs *connState) message {
}
defer ref.DecRef()
- ref.openedMu.Lock()
- defer ref.openedMu.Unlock()
-
- // Has it been opened already?
- if ref.opened || !CanOpen(ref.mode) {
- return newErr(syscall.EINVAL)
- }
-
- if ref.mode.IsDir() {
- // Directory must be opened ReadOnly.
- if t.Flags&OpenFlagsModeMask != ReadOnly {
- return newErr(syscall.EISDIR)
- }
- // Directory not truncatable.
- if t.Flags&OpenTruncate != 0 {
- return newErr(syscall.EISDIR)
- }
- }
-
var (
qid QID
ioUnit uint32
@@ -326,6 +307,22 @@ func (t *Tlopen) handle(cs *connState) message {
return syscall.EINVAL
}
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return syscall.EINVAL
+ }
+
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return syscall.EISDIR
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return syscall.EISDIR
+ }
+ }
+
osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
return err
}); err != nil {
@@ -366,7 +363,7 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -437,7 +434,7 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -476,7 +473,7 @@ func (t *Tlink) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -518,7 +515,7 @@ func (t *Trenameat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -561,7 +558,7 @@ func (t *Tunlinkat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -701,13 +698,12 @@ func (t *Tread) handle(cs *connState) message {
)
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be read? Check permissions.
- if openFlags&OpenFlagsModeMask == WriteOnly {
+ if ref.openFlags&OpenFlagsModeMask == WriteOnly {
return syscall.EPERM
}
@@ -731,13 +727,12 @@ func (t *Twrite) handle(cs *connState) message {
var n int
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EPERM
}
@@ -778,7 +773,7 @@ func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -820,7 +815,7 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -898,13 +893,12 @@ func (t *Tallocate) handle(cs *connState) message {
if err := ref.safelyWrite(func() error {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EBADF
}
@@ -1049,8 +1043,8 @@ func (t *Treaddir) handle(cs *connState) message {
return syscall.EINVAL
}
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1076,8 +1070,8 @@ func (t *Tfsync) handle(cs *connState) message {
defer ref.DecRef()
if err := ref.safelyRead(func() (err error) {
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1185,8 +1179,13 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI
}
// Has it been opened already?
- if _, opened := ref.OpenFlags(); opened {
- err = syscall.EBUSY
+ err = ref.safelyRead(func() (err error) {
+ if ref.opened {
+ return syscall.EBUSY
+ }
+ return nil
+ })
+ if err != nil {
return
}
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e605b14c..2e3d427ae 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string
// case.
defer checkDeleted(h, dst)
} else {
+ // If the type is different than the destination, then
+ // we expect the rename to fail. We expect that this
+ // is returned.
+ //
+ // If the file being renamed to itself, this is
+ // technically allowed and a no-op, but all the
+ // triggers will fire.
if !selfRename {
- // If the type is different than the
- // destination, then we expect the rename to
- // fail. We expect ensure that this is
- // returned.
expectedErr = syscall.EINVAL
- } else {
- // This is the file being renamed to itself.
- // This is technically allowed and a no-op, but
- // all the triggers will fire.
}
dst.Close()
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 3736f12a3..8c5c434fd 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -134,12 +134,11 @@ type fidRef struct {
// The node above will be closed only when refs reaches zero.
refs int64
- // openedMu protects opened and openFlags.
- openedMu sync.Mutex
-
// opened indicates whether this has been opened already.
//
// This is updated in handlers.go.
+ //
+ // opened is protected by pathNode.opMu or renameMu (for write).
opened bool
// mode is the fidRef's mode from the walk. Only the type bits are
@@ -151,6 +150,8 @@ type fidRef struct {
// openFlags is the mode used in the open.
//
// This is updated in handlers.go.
+ //
+ // openFlags is protected by pathNode.opMu or renameMu (for write).
openFlags OpenFlags
// pathNode is the current pathNode for this FID.
@@ -177,13 +178,6 @@ type fidRef struct {
deleted uint32
}
-// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
-func (f *fidRef) OpenFlags() (OpenFlags, bool) {
- f.openedMu.Lock()
- defer f.openedMu.Unlock()
- return f.openFlags, f.opened
-}
-
// IncRef increases the references on a fid.
func (f *fidRef) IncRef() {
atomic.AddInt64(&f.refs, 1)
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index e7406b374..a29f06ddb 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) {
for i := 0; i < b.N; i++ {
tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(1) {
- b.Fatalf("got tag %v expected 1", tag)
+ b.Errorf("got tag %v expected 1", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %T expected *Rflush", m)
+ b.Errorf("got message %T expected *Rflush", m)
}
if err := send(server, Tag(2), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := send(client, Tag(1), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
tag, m, err := recv(client, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(2) {
- b.Fatalf("got tag %v expected 2", tag)
+ b.Errorf("got tag %v expected 2", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %v expected *Rflush", m)
+ b.Errorf("got message %v expected *Rflush", m)
}
}
}
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index a1b2e0cfe..54e825b28 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package pool provides a trivial integer pool.
package pool
import (
diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD
index bfa1daa10..0377c0876 100644
--- a/pkg/refsvfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -9,7 +9,7 @@ go_template(
"refs_template.go",
],
opt_consts = [
- "logTrace",
+ "enableLogging",
],
types = [
"T",
diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index f64b6c6ae..3fbc91aa5 100644
--- a/pkg/refsvfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -74,11 +74,6 @@ func (r *Refs) LogRefs() bool {
return enableLogging
}
-// EnableLeakCheck enables reference leak checking on r.
-func (r *Refs) EnableLeakCheck() {
- refsvfs2.Register(r)
-}
-
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *Refs) ReadRefs() int64 {
@@ -136,7 +131,7 @@ func (r *Refs) TryIncRef() bool {
func (r *Refs) DecRef(destroy func()) {
v := atomic.AddInt64(&r.refCount, -1)
if enableLogging {
- refsvfs2.LogDecRef(r, v+1)
+ refsvfs2.LogDecRef(r, v)
}
switch {
case v < 0:
@@ -153,6 +148,6 @@ func (r *Refs) DecRef(destroy func()) {
func (r *Refs) afterLoad() {
if r.ReadRefs() > 0 {
- r.EnableLeakCheck()
+ refsvfs2.Register(r)
}
}
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index e7fd30743..7857f5853 100644
--- a/pkg/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block {
}
}
-// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is
+// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is
// safe to access without safecopy.
//
-// Preconditions: ptr+len does not overflow.
-func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, false)
+// Preconditions: ptr+length does not overflow.
+func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, false)
}
// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which
// is not safe to access without safecopy.
//
// Preconditions: ptr+len does not overflow.
-func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, true)
+func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, true)
}
-func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block {
- if uptr := uintptr(ptr); uptr+uintptr(len) < uptr {
- panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len))
+func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block {
+ if uptr := uintptr(ptr); uptr+uintptr(length) < uptr {
+ panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length))
}
return Block{
start: ptr,
- length: len,
+ length: length,
needSafecopy: needSafecopy,
}
}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index 752e2dc32..ec17ebc4d 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -79,7 +79,7 @@ func Install(rules SyscallRules) error {
// Perform the actual installation.
if errno := SetFilter(instrs); errno != 0 {
- return fmt.Errorf("Failed to set filter: %v", errno)
+ return fmt.Errorf("failed to set filter: %v", errno)
}
log.Infof("Seccomp filters installed.")
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index 7cd895cc7..652c010da 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package segment is a test package.
package segment
type setFunctions struct{}
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index d75d665ae..dd2effdf9 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 19ce99d25..840e53d33 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -17,27 +17,10 @@
package arch
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/usermem"
)
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
-
// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index c9fb55d00..35d2e07c3 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() {
}
}
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
}
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
}
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
@@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 {
func (s *SignalInfo) SetArch(val uint32) {
usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(usermem.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2bf3c45e1..91b8fb44f 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -193,7 +193,7 @@ func (p *Profile) StopTrace(_, _ *struct{}) error {
defer p.mu.Unlock()
if p.traceFile == nil {
- return errors.New("Execution tracing not started")
+ return errors.New("execution tracing not started")
}
// Similarly to the case above, if tasks have not ended traces, we will
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index d800f2c85..62eaca965 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
Callback: func(err error) {
if err == nil {
log.Infof("Save succeeded: exiting...")
+ s.Kernel.SetSaveSuccess(false /* autosave */)
} else {
log.Warningf("Save failed: exiting...")
s.Kernel.SetSaveError(err)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index 314661475..badd5b073 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package fdimport provides the Import function.
package fdimport
import (
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index ea85ab33c..5c3e852e9 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -49,13 +49,13 @@ go_library(
"//pkg/amutex",
"//pkg/context",
"//pkg/log",
- "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/secio",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index ff2fe6712..8e0aa9019 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err
// copyUpBuffers is a buffer pool for copying file content. The buffer
// size is the same used by io.Copy.
-var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+var copyUpBuffers = sync.Pool{
+ New: func() interface{} {
+ b := make([]byte, 8*usermem.PageSize)
+ return &b
+ },
+}
// copyContentsLocked copies the contents of lower to upper. It panics if
// less than size bytes can be copied.
@@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
- buf := copyUpBuffers.Get().([]byte)
+ buf := copyUpBuffers.Get().(*[]byte)
defer copyUpBuffers.Put(buf)
// Transfer the contents.
@@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// optimizations could be self-defeating. So we leave this as simple as possible.
var offset int64
for {
- nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset)
if err != nil && err != io.EOF {
return err
}
@@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
}
return nil
}
- nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index c7a11eec1..e04784db2 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) {
wg.Add(1)
go func(o *overlayTestFile) {
if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil {
- t.Fatalf("failed to copy up: %v", err)
+ t.Errorf("failed to copy up: %v", err)
}
wg.Done()
}(file)
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 72ea70fcf..57f904801 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -17,13 +17,12 @@ package fs
import (
"math"
"sync/atomic"
- "time"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
@@ -33,28 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- // RecordWaitTime controls writing metrics for filesystem reads.
- // Enabling this comes at a small CPU cost due to performing two
- // monotonic clock reads per read call.
- //
- // Note that this is only performed in the direct read path, and may
- // not be consistently applied for other forms of reads, such as
- // splice.
- RecordWaitTime = false
-
- reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
- readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
-)
-
-// IncrementWait increments the given wait time metric, if enabled.
-func IncrementWait(m *metric.Uint64Metric, start time.Time) {
- if !RecordWaitTime {
- return
- }
- m.IncrementBy(uint64(time.Since(start)))
-}
-
// FileMaxOffset is the maximum possible file offset.
const FileMaxOffset = math.MaxInt64
@@ -257,22 +234,19 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
//
// Returns syserror.ErrInterrupted if reading was interrupted.
func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) {
- var start time.Time
- if RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+
if !f.mu.Lock(ctx) {
- IncrementWait(readWait, start)
return 0, syserror.ErrInterrupted
}
- reads.Increment()
+ fsmetric.Reads.Increment()
n, err := f.FileOperations.Read(ctx, f, dst, f.offset)
if n > 0 && !f.flags.NonSeekable {
atomic.AddInt64(&f.offset, n)
}
f.mu.Unlock()
- IncrementWait(readWait, start)
return n, err
}
@@ -282,19 +256,16 @@ func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error)
//
// Otherwise same as Readv.
func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+
if !f.mu.Lock(ctx) {
- IncrementWait(readWait, start)
return 0, syserror.ErrInterrupted
}
- reads.Increment()
+ fsmetric.Reads.Increment()
n, err := f.FileOperations.Read(ctx, f, dst, offset)
f.mu.Unlock()
- IncrementWait(readWait, start)
return n, err
}
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
index 8049538f2..ec3d3f96c 100644
--- a/pkg/sentry/fs/filetest/filetest.go
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File {
// Read just fails the request.
func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Readv not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Read not implemented")
}
// Write just fails the request.
func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Writev not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Write not implemented")
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index d2dbff268..a020da53b 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -65,7 +65,7 @@ var (
// runs with the lock held for reading. AsyncBarrier will take the lock
// for writing, thus ensuring that all Async work completes before
// AsyncBarrier returns.
- workMu sync.RWMutex
+ workMu sync.CrossGoroutineRWMutex
// asyncError is used to store up to one asynchronous execution error.
asyncError = make(chan error, 1)
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index fea135eea..4c30098cd 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -28,7 +28,6 @@ go_library(
"//pkg/context",
"//pkg/fd",
"//pkg/log",
- "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/safemem",
@@ -38,6 +37,7 @@ go_library(
"//pkg/sentry/fs/fdpipe",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/host",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index d481baf77..e5579095b 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType {
return fs.BlockDevice
case pattr.Mode.IsSocket():
return fs.Socket
- case pattr.Mode.IsRegular():
- fallthrough
default:
return fs.RegularFile
}
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index c0bc63a32..bb63448cb 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -21,27 +21,17 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.")
- opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.")
- opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.")
- reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
- readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
- readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
- readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
-)
-
// fileOperations implements fs.FileOperations for a remote file system.
//
// +stateify savable
@@ -101,14 +91,14 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
if flags.Write {
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
- opensWX.Increment()
+ fsmetric.GoferOpensWX.Increment()
log.Warningf("Opened a writable executable: %q", name)
}
}
if handles.Host != nil {
- opensHost.Increment()
+ fsmetric.GoferOpensHost.Increment()
} else {
- opens9P.Increment()
+ fsmetric.GoferOpens9P.Increment()
}
return fs.NewFile(ctx, dirent, flags, f)
}
@@ -278,20 +268,17 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
// use this function rather than using a defer in Read() to avoid the performance hit of defer.
func (f *fileOperations) incrementReadCounters(start time.Time) {
if f.handles.Host != nil {
- readsHost.Increment()
- fs.IncrementWait(readWaitHost, start)
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
} else {
- reads9P.Increment()
- fs.IncrementWait(readWait9P, start)
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
}
}
// Read implements fs.FileOperations.Read.
func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if fs.RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
f.incrementReadCounters(start)
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 3a225fd39..9d6fdd08f 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -117,7 +117,7 @@ type inodeFileState struct {
// loading is acquired when the inodeFileState begins an asynchronous
// load. It releases when the load is complete. Callers that require all
// state to be available should call waitForLoad() to ensure that.
- loading sync.Mutex `state:".(struct{})"`
+ loading sync.CrossGoroutineMutex `state:".(struct{})"`
// savedUAttr is only allocated during S/R. It points to the save-time
// unstable attributes and is used to validate restore-time ones.
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 004910453..9b3d8166a 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -18,9 +18,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -28,8 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
-
// Inode is a file system object that can be simultaneously referenced by different
// components of the VFS (Dirent, fs.File, etc).
//
@@ -247,7 +245,7 @@ func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File,
if i.overlay != nil {
return overlayGetFile(ctx, i.overlay, d, flags)
}
- opens.Increment()
+ fsmetric.Opens.Increment()
return i.InodeOperations.GetFile(ctx, d, flags)
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index f8aad2dbd..b998fb75d 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
children := map[string]*fs.Inode{
"hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil),
+ "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index aa7199014..b521a86a2 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -15,12 +15,12 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
- "//pkg/metric",
"//pkg/safemem",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index d6c65301c..e04cd608d 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -18,14 +18,13 @@ import (
"fmt"
"io"
"math"
- "time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -35,13 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-var (
- opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
- opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
- reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
- readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
-)
-
// fileInodeOperations implements fs.InodeOperations for a regular tmpfs file.
// These files are backed by pages allocated from a platform.Memory, and may be
// directly mapped.
@@ -157,9 +149,9 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare
// GetFile implements fs.InodeOperations.GetFile.
func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
if flags.Write {
- opensW.Increment()
+ fsmetric.TmpfsOpensW.Increment()
} else if flags.Read {
- opensRO.Increment()
+ fsmetric.TmpfsOpensRO.Increment()
}
flags.Pread = true
flags.Pwrite = true
@@ -319,14 +311,12 @@ func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) {
}
func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if fs.RecordWaitTime {
- start = time.Now()
- }
- reads.Increment()
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start)
+ fsmetric.TmpfsReads.Increment()
+
// Zero length reads for tmpfs are no-ops.
if dst.NumBytes() == 0 {
- fs.IncrementWait(readWait, start)
return 0, nil
}
@@ -343,7 +333,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm
size := f.attr.Size
f.dataMu.RUnlock()
if offset >= size {
- fs.IncrementWait(readWait, start)
return 0, io.EOF
}
@@ -354,7 +343,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm
f.attr.AccessTime = ktime.NowFromContext(ctx)
f.attrMu.Unlock()
}
- fs.IncrementWait(readWait, start)
return n, err
}
diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go
index 1b3459c1d..4ab894965 100644
--- a/pkg/sentry/fsimpl/fuse/connection_control.go
+++ b/pkg/sentry/fsimpl/fuse/connection_control.go
@@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
Flags: fuseDefaultInitFlags,
}
- req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
// Since there is no task to block on and FUSE_INIT is the request
// to unblock other requests, use nil.
return conn.CallAsync(nil, req)
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 91d16c1cf..d8b0d7657 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) {
var futNormal []*futureResponse
for i := 0; i < int(numRequests); i++ {
- req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
fut, err := conn.callFutureLocked(task, req)
if err != nil {
t.Fatalf("callFutureLocked failed: %v", err)
@@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) {
}
// After abort, Call() should return directly with ENOTCONN.
- req, err := conn.NewRequest(creds, 0, 0, 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, 0, 0, 0, testObj)
_, err = conn.Call(task, req)
if err != syserror.ENOTCONN {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index 89c3ef079..1bbe6fdb7 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -363,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask {
var ready waiter.EventMask
- if fd.fs.umounted {
+ if fd.fs == nil || fd.fs.umounted {
ready |= waiter.EventErr
return ready & mask
}
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 95c475a65..bb2d0d31a 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
data: rand.Uint32(),
}
- req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
// Queue up a request.
// Analogous to Call except it doesn't block on the task.
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
index 8f220a04b..fcc5d9a2a 100644
--- a/pkg/sentry/fsimpl/fuse/directory.go
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent
}
// TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS.
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
res, err := fusefs.conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go
index 83f2816b7..e138b11f8 100644
--- a/pkg/sentry/fsimpl/fuse/file.go
+++ b/pkg/sentry/fsimpl/fuse/file.go
@@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) {
opcode = linux.FUSE_RELEASE
}
kernelTask := kernel.TaskFromContext(ctx)
- // ignoring errors and FUSE server reply is analogous to Linux's behavior.
- req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
- if err != nil {
- // No way to invoke Call() with an errored request.
- return
- }
+ // Ignoring errors and FUSE server reply is analogous to Linux's behavior.
+ req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
// The reply will be ignored since no callback is defined in asyncCallBack().
conn.CallAsync(kernelTask, req)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 23e827f90..3af807a21 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- return nil, nil, err
+ log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ return nil, nil, syserror.EINVAL
}
kernelTask := kernel.TaskFromContext(ctx)
@@ -360,12 +361,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
in.Flags &= ^uint32(linux.O_TRUNC)
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
- if err != nil {
- return nil, err
- }
-
// Send the request and receive the reply.
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -485,10 +482,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
return syserror.EINVAL
}
in := linux.FUSEUnlinkIn{Name: name}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
- if err != nil {
- return err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return err
@@ -515,11 +509,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
in := linux.FUSERmDirIn{Name: name}
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return err
@@ -535,10 +525,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
return nil, syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
- if err != nil {
- return nil, err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -574,10 +561,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
return "", syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
- if err != nil {
- return "", err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return "", err
@@ -680,11 +664,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
GetAttrFlags: flags,
Fh: fh,
}
- req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
- if err != nil {
- return linux.FUSEAttr{}, err
- }
-
+ req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return linux.FUSEAttr{}, err
@@ -803,11 +783,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
UID: opts.Stat.UID,
GID: opts.Stat.GID,
}
- req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
res, err := conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 2d396e84c..23ce91849 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
in.Size = pagesCanRead << usermem.PageShift
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
- if err != nil {
- return nil, 0, err
- }
-
// TODO(gvisor.dev/issue/3247): support async read.
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
res, err := fs.conn.Call(t, req)
if err != nil {
return nil, 0, err
@@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
- if err != nil {
- return 0, err
- }
-
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
req.payload = data[written : written+toWrite]
// TODO(gvisor.dev/issue/3247): support async write.
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 7fa00569b..41d679358 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
+ _ = src // Remove unused warning.
}
// SizeBytes is the size of the payload of the FUSE_INIT response.
@@ -104,7 +105,7 @@ type Request struct {
}
// NewRequest creates a new request that can be sent to the FUSE server.
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request {
conn.fd.mu.Lock()
defer conn.fd.mu.Unlock()
conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
@@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint
id: hdr.Unique,
hdr: &hdr,
data: buf,
- }, nil
+ }
}
// futureResponse represents an in-flight request, that may or may not have
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4c3e9acf8..807b6ed1f 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -59,6 +59,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 2294c490e..df27554d3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
@@ -985,14 +986,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
switch d.fileType() {
case linux.S_IFREG:
if !d.fs.opts.regularFilesUseSpecialFileFD {
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
+ if err := d.ensureSharedHandle(ctx, ats.MayRead(), ats.MayWrite(), trunc); err != nil {
return nil, err
}
- fd := &regularFileFD{}
- fd.LockFD.Init(&d.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
- AllowDirectIO: true,
- }); err != nil {
+ fd, err := newRegularFileFD(mnt, d, opts.Flags)
+ if err != nil {
return nil, err
}
vfd = &fd.vfsfd
@@ -1019,6 +1017,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
+ if atomic.LoadInt32(&d.readFD) >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
return &fd.vfsfd, nil
case linux.S_IFLNK:
// Can't open symlinks without O_PATH (which is unimplemented).
@@ -1110,7 +1113,7 @@ retry:
return nil, err
}
}
- fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, d, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
@@ -1205,11 +1208,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
// Finally, construct a file description representing the created file.
var childVFSFD *vfs.FileDescription
if useRegularFileFD {
- fd := &regularFileFD{}
- fd.LockFD.Init(&child.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
- AllowDirectIO: true,
- }); err != nil {
+ fd, err := newRegularFileFD(mnt, child, opts.Flags)
+ if err != nil {
return nil, err
}
childVFSFD = &fd.vfsfd
@@ -1221,7 +1221,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
if fdobj != nil {
h.fd = int32(fdobj.Release())
}
- fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 75a836899..3cdb1e659 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -743,7 +743,9 @@ type dentry struct {
// for memory mappings. If mmapFD is -1, no such FD is available, and the
// internal page cache implementation is used for memory mappings instead.
//
- // These fields are protected by handleMu.
+ // These fields are protected by handleMu. readFD, writeFD, and mmapFD are
+ // additionally written using atomic memory operations, allowing them to be
+ // read (albeit racily) with atomic.LoadInt32() without locking handleMu.
//
// readFile and writeFile may or may not represent the same p9.File. Once
// either p9.File transitions from closed (isNil() == true) to open
@@ -1351,16 +1353,11 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
return
}
if refs > 0 {
- if d.cached {
- // This isn't strictly necessary (fs.cachedDentries is permitted to
- // contain dentries with non-zero refs, which are skipped by
- // fs.evictCachedDentryLocked() upon reaching the end of the LRU),
- // but since we are already holding fs.renameMu for writing we may
- // as well.
- d.fs.cachedDentries.Remove(d)
- d.fs.cachedDentriesLen--
- d.cached = false
- }
+ // This isn't strictly necessary (fs.cachedDentries is permitted to
+ // contain dentries with non-zero refs, which are skipped by
+ // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but
+ // since we are already holding fs.renameMu for writing we may as well.
+ d.removeFromCacheLocked()
return
}
// Deleted and invalidated dentries with zero references are no longer
@@ -1369,20 +1366,18 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
if d.isDeleted() {
d.watches.HandleDeletion(ctx)
}
- if d.cached {
- d.fs.cachedDentries.Remove(d)
- d.fs.cachedDentriesLen--
- d.cached = false
- }
+ d.removeFromCacheLocked()
d.destroyLocked(ctx)
return
}
- // If d still has inotify watches and it is not deleted or invalidated, we
- // cannot cache it and allow it to be evicted. Otherwise, we will lose its
- // watches, even if a new dentry is created for the same file in the future.
- // Note that the size of d.watches cannot concurrently transition from zero
- // to non-zero, because adding a watch requires holding a reference on d.
+ // If d still has inotify watches and it is not deleted or invalidated, it
+ // can't be evicted. Otherwise, we will lose its watches, even if a new
+ // dentry is created for the same file in the future. Note that the size of
+ // d.watches cannot concurrently transition from zero to non-zero, because
+ // adding a watch requires holding a reference on d.
if d.watches.Size() > 0 {
+ // As in the refs > 0 case, this is not strictly necessary.
+ d.removeFromCacheLocked()
return
}
@@ -1413,6 +1408,15 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
}
}
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) removeFromCacheLocked() {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+}
+
// Precondition: fs.renameMu must be locked for writing; it may be temporarily
// unlocked.
func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
@@ -1426,12 +1430,10 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
// * fs.cachedDentriesLen != 0.
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
victim := fs.cachedDentries.Back()
- fs.cachedDentries.Remove(victim)
- fs.cachedDentriesLen--
- victim.cached = false
- // victim.refs may have become non-zero from an earlier path resolution
- // since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 {
+ victim.removeFromCacheLocked()
+ // victim.refs or victim.watches.Size() may have become non-zero from an
+ // earlier path resolution since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 {
if victim.parent != nil {
victim.parent.dirMu.Lock()
if !victim.vfsd.IsDead() {
@@ -1668,7 +1670,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
}
fdsToClose = append(fdsToClose, d.readFD)
invalidateTranslations = true
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
} else {
// Otherwise, we want to avoid invalidating existing
// memmap.Translations (which is expensive); instead, use
@@ -1689,15 +1691,15 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
h.fd = d.readFD
}
} else {
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
}
if d.writeFD != h.fd && d.writeFD >= 0 {
fdsToClose = append(fdsToClose, d.writeFD)
}
- d.writeFD = h.fd
- d.mmapFD = h.fd
+ atomic.StoreInt32(&d.writeFD, h.fd)
+ atomic.StoreInt32(&d.mmapFD, h.fd)
} else if openReadable && d.readFD < 0 {
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
// If the file has not been opened for writing, the new FD may
// be used for read-only memory mappings. If the file was
// previously opened for reading (without an FD), then existing
@@ -1705,10 +1707,10 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// invalidate those mappings.
if d.writeFile.isNil() {
invalidateTranslations = !d.readFile.isNil()
- d.mmapFD = h.fd
+ atomic.StoreInt32(&d.mmapFD, h.fd)
}
} else if openWritable && d.writeFD < 0 {
- d.writeFD = h.fd
+ atomic.StoreInt32(&d.writeFD, h.fd)
if d.readFD >= 0 {
// We have an existing read-only FD, but the file has just
// been opened for writing, so we need to start supporting
@@ -1717,7 +1719,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// writable memory mappings. Switch to using the internal
// page cache.
invalidateTranslations = true
- d.mmapFD = -1
+ atomic.StoreInt32(&d.mmapFD, -1)
}
} else {
// The new FD is not useful.
@@ -1729,7 +1731,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// memory mappings. However, we have no writable host FD. Switch to
// using the internal page cache.
invalidateTranslations = true
- d.mmapFD = -1
+ atomic.StoreInt32(&d.mmapFD, -1)
}
// Switch to new fids.
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 652142ecc..283b220bb 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
@@ -48,6 +49,25 @@ type regularFileFD struct {
off int64
}
+func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, error) {
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
+ fsmetric.GoferOpensWX.Increment()
+ }
+ if atomic.LoadInt32(&d.mmapFD) >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
+ return fd, nil
+}
+
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *regularFileFD) Release(context.Context) {
}
@@ -89,6 +109,18 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ d := fd.dentry()
+ defer func() {
+ if atomic.LoadInt32(&d.readFD) >= 0 {
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
+ } else {
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
+ }
+ }()
+
if offset < 0 {
return 0, syserror.EINVAL
}
@@ -102,7 +134,6 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// Check for reading at EOF before calling into MM (but not under
// InteropModeShared, which makes d.size unreliable).
- d := fd.dentry()
if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) {
return 0, io.EOF
}
@@ -647,10 +678,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
// Whether or not we have a host FD, we're not allowed to use it.
return syserror.ENODEV
}
- d.handleMu.RLock()
- haveFD := d.mmapFD >= 0
- d.handleMu.RUnlock()
- if !haveFD {
+ if atomic.LoadInt32(&d.mmapFD) < 0 {
return syserror.ENODEV
}
default:
@@ -668,10 +696,7 @@ func (d *dentry) mayCachePages() bool {
if d.fs.opts.forcePageCache {
return true
}
- d.handleMu.RLock()
- haveFD := d.mmapFD >= 0
- d.handleMu.RUnlock()
- return haveFD
+ return atomic.LoadInt32(&d.mmapFD) >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 625400c0b..089955a96 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -70,7 +71,7 @@ type specialFileFD struct {
buf []byte
}
-func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) {
ftype := d.fileType()
seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK
haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
@@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
seekable: seekable,
haveQueue: haveQueue,
}
- fd.LockFD.Init(locks)
+ fd.LockFD.Init(&d.locks)
if haveQueue {
if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
return nil, err
@@ -98,6 +99,14 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
d.fs.syncMu.Lock()
d.fs.specialFileFDs[fd] = struct{}{}
d.fs.syncMu.Unlock()
+ if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
+ fsmetric.GoferOpensWX.Increment()
+ }
+ if h.fd >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
return fd, nil
}
@@ -161,6 +170,17 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ defer func() {
+ if fd.handle.fd >= 0 {
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
+ } else {
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
+ }
+ }()
+
if fd.seekable && offset < 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index c14abcff4..565d723f0 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -286,7 +286,7 @@ func (d *Dentry) cacheLocked(ctx context.Context) {
refs := atomic.LoadInt64(&d.refs)
if refs == -1 {
// Dentry has already been destroyed.
- panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d))
+ return
}
if refs > 0 {
if d.cached {
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 469f3a33d..27b00cf6f 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -16,7 +16,6 @@ package overlay
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
return err
}
defer newFD.DecRef(ctx)
- bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
- for {
- readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
- if readErr != nil && readErr != io.EOF {
- cleanupUndoCopyUp()
- return readErr
- }
- total := int64(0)
- for total < readN {
- writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
- total += writeN
- if writeErr != nil {
- cleanupUndoCopyUp()
- return writeErr
- }
- }
- if readErr == io.EOF {
- break
- }
+ if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil {
+ cleanupUndoCopyUp()
+ return err
}
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 2b89a7a6d..25c785fd4 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript
for e, mask := range fd.lowerWaiters {
fd.cachedFD.EventUnregister(e)
upperFD.EventRegister(e, mask)
- if ready&mask != 0 {
- e.Callback.Callback(e)
+ if m := ready & mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 5cf8a071a..d4f6a5a9b 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -208,7 +208,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
for _, se := range n.kernel.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
@@ -351,7 +351,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
for _, se := range k.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
sops, ok := s.Impl().(socket.SocketVFS2)
@@ -516,7 +516,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
for _, se := range d.kernel.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
sops, ok := s.Impl().(socket.SocketVFS2)
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 7c7afdcfa..25c407d98 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
"shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
"shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 10f1452ef..246bd87bc 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package signalfd provides basic signalfd file implementations.
package signalfd
import (
@@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index fe520b6fd..09957c2b7 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -67,6 +67,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e39cd305b..9296db2fb 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -381,6 +382,8 @@ afterTrailingSymlink:
creds := rp.Credentials()
child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
parentDir.insertChildLocked(child, name)
+ child.IncRef()
+ defer child.DecRef(ctx)
unlock()
fd, err := child.open(ctx, rp, &opts, true)
if err != nil {
@@ -437,6 +440,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
}
+ if fd.vfsfd.IsWritable() {
+ fsmetric.TmpfsOpensW.Increment()
+ } else if fd.vfsfd.IsReadable() {
+ fsmetric.TmpfsOpensRO.Increment()
+ }
return &fd.vfsfd, nil
case *directory:
// Can't open directories writably.
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index f8e0cffb0..6255a7c84 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -359,6 +360,10 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start)
+ fsmetric.TmpfsReads.Increment()
+
if offset < 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index add5dd48e..a4ad625bb 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -107,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// Dentries which may have a reference count of zero, and which therefore
// should be dropped once traversal is complete, are appended to ds.
//
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
-// !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+// * !rp.Done().
func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -158,15 +160,19 @@ afterSymlink:
return child, nil
}
-// verifyChild verifies the hash of child against the already verified hash of
-// the parent to ensure the child is expected. verifyChild triggers a sentry
-// panic if unexpected modifications to the file system are detected. In
-// noCrashOnVerificationFailure mode it returns a syserror instead.
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// verifyChildLocked verifies the hash of child against the already verified
+// hash of the parent to ensure the child is expected. verifyChild triggers a
+// sentry panic if unexpected modifications to the file system are detected. In
+// ErrorOnViolation mode it returns a syserror instead.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+//
// TODO(b/166474175): Investigate all possible errors returned in this
// function, and make sure we differentiate all errors that indicate unexpected
// modifications to the file system from the ones that are not harmful.
-func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
+func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -248,7 +254,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: parentMerkleFD,
Ctx: ctx,
}
@@ -268,7 +274,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contain the hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ parent.hashMu.RLock()
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
Out: &buf,
File: &fdReader,
Tree: &fdReader,
@@ -284,21 +291,27 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
- }); err != nil && err != io.EOF {
+ })
+ parent.hashMu.RUnlock()
+ if err != nil && err != io.EOF {
return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
+ child.hashMu.Lock()
if len(child.hash) == 0 {
child.hash = buf.Bytes()
}
+ child.hashMu.Unlock()
return child, nil
}
-// verifyStatAndChildren verifies the stat and children names against the
+// verifyStatAndChildrenLocked verifies the stat and children names against the
// verified hash. The mode/uid/gid and childrenNames of the file is cached
// after verified.
-func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat linux.Statx) error {
+//
+// Preconditions: d.dirMu must be locked.
+func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -384,12 +397,13 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
}
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd,
Ctx: ctx,
}
var buf bytes.Buffer
+ d.hashMu.RLock()
params := &merkletree.VerifyParams{
Out: &buf,
Tree: &fdReader,
@@ -407,6 +421,7 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
Expected: d.hash,
DataAndTreeInSameFile: false,
}
+ d.hashMu.RUnlock()
if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
params.DataAndTreeInSameFile = true
}
@@ -421,7 +436,9 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
return nil
}
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
// If verity is enabled on child, we should check again whether
@@ -470,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// be cached before enabled.
if fs.allowRuntimeEnable {
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
return nil, err
}
}
@@ -486,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
if err != nil {
return nil, err
}
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
return nil, err
}
}
@@ -506,7 +523,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
return child, nil
}
-// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -597,13 +616,13 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
if child.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
child.destroyLocked(ctx)
return nil, err
}
@@ -617,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// rp.Start().Impl().(*dentry)). It does not check that the returned directory
// is searchable by the provider of rp.
//
-// Preconditions: fs.renameMu must be locked. !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
@@ -958,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ d.dirMu.Lock()
if d.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return linux.Statx{}, err
}
}
+ d.dirMu.Unlock()
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 87dabe038..66029c64d 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -19,6 +19,18 @@
// The verity file system is read-only, except for one case: when
// allowRuntimeEnable is true, additional Merkle files can be generated using
// the FS_IOC_ENABLE_VERITY ioctl.
+//
+// Lock order:
+//
+// filesystem.renameMu
+// dentry.dirMu
+// fileDescription.mu
+// filesystem.verityMu
+// dentry.hashMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
package verity
import (
@@ -52,6 +64,10 @@ const (
// tree file for "/foo" is "/.merkle.verity.foo".
merklePrefix = ".merkle.verity."
+ // merkleRootPrefix is the prefix of the Merkle tree root file. This
+ // needs to be different from merklePrefix to avoid name collision.
+ merkleRootPrefix = ".merkleroot.verity."
+
// merkleOffsetInParentXattr is the extended attribute name specifying the
// offset of the child hash in its parent's Merkle tree.
merkleOffsetInParentXattr = "user.merkle.offset"
@@ -76,13 +92,8 @@ const (
)
var (
- // noCrashOnVerificationFailure indicates whether the sandbox should panic
- // whenever verification fails. If true, an error is returned instead of
- // panicking. This should only be set for tests.
- //
- // TODO(b/165661693): Decide whether to panic or return error based on this
- // flag.
- noCrashOnVerificationFailure bool
+ // action specifies the action towards detected violation.
+ action ViolationAction
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
@@ -93,6 +104,18 @@ var (
// content.
type HashAlgorithm int
+// ViolationAction is a type specifying the action when an integrity violation
+// is detected.
+type ViolationAction int
+
+const (
+ // PanicOnViolation terminates the sentry on detected violation.
+ PanicOnViolation ViolationAction = 0
+ // ErrorOnViolation returns an error from the violating system call on
+ // detected violation.
+ ErrorOnViolation = 1
+)
+
// Currently supported hashing algorithms include SHA256 and SHA512.
const (
SHA256 HashAlgorithm = iota
@@ -187,10 +210,8 @@ type InternalFilesystemOptions struct {
// system wrapped by verity file system.
LowerGetFSOptions vfs.GetFilesystemOptions
- // NoCrashOnVerificationFailure indicates whether the sandbox should
- // panic whenever verification fails. If true, an error is returned
- // instead of panicking. This should only be set for tests.
- NoCrashOnVerificationFailure bool
+ // Action specifies the action on an integrity violation.
+ Action ViolationAction
}
// Name implements vfs.FilesystemType.Name.
@@ -202,10 +223,10 @@ func (FilesystemType) Name() string {
func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
-// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+// unexpected modification to the file system is detected. In ErrorOnViolation
+// mode, it returns EIO, otherwise it panic.
func alertIntegrityViolation(msg string) error {
- if noCrashOnVerificationFailure {
+ if action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
@@ -218,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure
+ action = iopts.Action
// Mount the lower file system. The lower file system is wrapped inside
// verity, and should not be exposed or connected.
@@ -246,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merklePrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -372,12 +393,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
}
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return nil, nil, err
}
}
+ d.hashMu.Lock()
copy(d.hash, iopts.RootHash)
+ d.hashMu.Unlock()
d.vfsd.Init(d)
fs.rootDentry = d
@@ -402,7 +425,8 @@ type dentry struct {
fs *filesystem
// mode, uid, gid and size are the file mode, owner, group, and size of
- // the file in the underlying file system.
+ // the file in the underlying file system. They are set when a dentry
+ // is initialized, and never modified.
mode uint32
uid uint32
gid uint32
@@ -425,18 +449,22 @@ type dentry struct {
// childrenNames stores the name of all children of the dentry. This is
// used by verity to check whether a child is expected. This is only
- // populated by enableVerity.
+ // populated by enableVerity. childrenNames is also protected by dirMu.
childrenNames map[string]struct{}
- // lowerVD is the VirtualDentry in the underlying file system.
+ // lowerVD is the VirtualDentry in the underlying file system. It is
+ // never modified after initialized.
lowerVD vfs.VirtualDentry
// lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree
- // in the underlying file system.
+ // in the underlying file system. It is never modified after
+ // initialized.
lowerMerkleVD vfs.VirtualDentry
- // hash is the calculated hash for the current file or directory.
- hash []byte
+ // hash is the calculated hash for the current file or directory. hash
+ // is protected by hashMu.
+ hashMu sync.RWMutex `state:"nosave"`
+ hash []byte
}
// newDentry creates a new dentry representing the given verity file. The
@@ -519,7 +547,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) {
// destroyLocked destroys the dentry.
//
-// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+// Preconditions:
+// * d.fs.renameMu must be locked for writing.
+// * d.refs == 0.
func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
@@ -599,6 +629,8 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
// mode, it returns true if the target has been enabled with
// ioctl(FS_IOC_ENABLE_VERITY).
func (d *dentry) verityEnabled() bool {
+ d.hashMu.RLock()
+ defer d.hashMu.RUnlock()
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
@@ -678,11 +710,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ fd.d.dirMu.Lock()
if fd.d.verityEnabled() {
- if err := fd.d.fs.verifyStatAndChildren(ctx, fd.d, stat); err != nil {
+ if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil {
return linux.Statx{}, err
}
}
+ fd.d.dirMu.Unlock()
return stat, nil
}
@@ -718,22 +752,24 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
return offset, nil
}
-// generateMerkle generates a Merkle tree file for fd. If fd points to a file
-// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
-// of the generated Merkle tree and the data size is returned. If fd points to
-// a regular file, the data is the content of the file. If fd points to a
-// directory, the data is all hahes of its children, written to the Merkle tree
-// file.
-func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
- fdReader := vfs.FileReadWriteSeeker{
+// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a
+// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The
+// hash of the generated Merkle tree and the data size is returned. If fd
+// points to a regular file, the data is the content of the file. If fd points
+// to a directory, the data is all hashes of its children, written to the Merkle
+// tree file.
+//
+// Preconditions: fd.d.fs.verityMu must be locked.
+func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) {
+ fdReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
- merkleWriter := vfs.FileReadWriteSeeker{
+ merkleWriter := FileReadWriteSeeker{
FD: fd.merkleWriter,
Ctx: ctx,
}
@@ -793,11 +829,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
return hash, uint64(params.Size), err
}
-// recordChildren writes the names of fd's children into the corresponding
-// Merkle tree file, and saves the offset/size of the map into xattrs.
+// recordChildrenLocked writes the names of fd's children into the
+// corresponding Merkle tree file, and saves the offset/size of the map into
+// xattrs.
//
-// Preconditions: fd.d.isDir() == true
-func (fd *fileDescription) recordChildren(ctx context.Context) error {
+// Preconditions:
+// * fd.d.fs.verityMu must be locked.
+// * fd.d.isDir() == true.
+func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error {
// Record the children names in the Merkle tree file.
childrenNames, err := json.Marshal(fd.d.childrenNames)
if err != nil {
@@ -847,7 +886,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
- hash, dataSize, err := fd.generateMerkle(ctx)
+ hash, dataSize, err := fd.generateMerkleLocked(ctx)
if err != nil {
return 0, err
}
@@ -888,11 +927,13 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
}
if fd.d.isDir() {
- if err := fd.recordChildren(ctx); err != nil {
+ if err := fd.recordChildrenLocked(ctx); err != nil {
return 0, err
}
}
- fd.d.hash = append(fd.d.hash, hash...)
+ fd.d.hashMu.Lock()
+ fd.d.hash = hash
+ fd.d.hashMu.Unlock()
return 0, nil
}
@@ -904,6 +945,9 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
}
var metadata linux.DigestMetadata
+ fd.d.hashMu.RLock()
+ defer fd.d.hashMu.RUnlock()
+
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
// verity is not enabled for the file. If allowRuntimeEnable is false,
// this is an integrity violation because all files should have verity
@@ -940,11 +984,13 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) {
f := int32(0)
+ fd.d.hashMu.RLock()
// All enabled files should store a hash. This flag is not settable via
// FS_IOC_SETFLAGS.
if len(fd.d.hash) != 0 {
f |= linux.FS_VERITY_FL
}
+ fd.d.hashMu.RUnlock()
t := kernel.TaskFromContext(ctx)
if t == nil {
@@ -1013,16 +1059,17 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
- dataReader := vfs.FileReadWriteSeeker{
+ dataReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
+ fd.d.hashMu.RLock()
n, err := merkletree.Verify(&merkletree.VerifyParams{
Out: dst.Writer(ctx),
File: &dataReader,
@@ -1040,6 +1087,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
+ fd.d.hashMu.RUnlock()
if err != nil {
return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
@@ -1065,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t
func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
}
+
+// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
+// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
+type FileReadWriteSeeker struct {
+ FD *vfs.FileDescription
+ Ctx context.Context
+ ROpts vfs.ReadOptions
+ WOpts vfs.WriteOptions
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
+ return int(n), err
+}
+
+// Read implements io.ReadWriteSeeker.Read.
+func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
+ return int(n), err
+}
+
+// Seek implements io.ReadWriteSeeker.Seek.
+func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
+ return f.FD.Seek(f.Ctx, offset, int32(whence))
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
+ return int(n), err
+}
+
+// Write implements io.ReadWriteSeeker.Write.
+func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
+ buf := usermem.BytesIOSequence(p)
+ n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
+ return int(n), err
+}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 7196e74eb..30d8b4355 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -35,16 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// rootMerkleFilename is the name of the root Merkle tree file.
-const rootMerkleFilename = "root.verity"
+const (
+ // rootMerkleFilename is the name of the root Merkle tree file.
+ rootMerkleFilename = "root.verity"
+ // maxDataSize is the maximum data size of a test file.
+ maxDataSize = 100000
+)
+
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
-// maxDataSize is the maximum data size written to the file for test.
-const maxDataSize = 100000
+func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry {
+ t.Helper()
+ d, ok := vd.Dentry().Impl().(*dentry)
+ if !ok {
+ t.Fatalf("can't assert %T as a *dentry", vd)
+ }
+ return d
+}
+
+// dentryFromFD returns the dentry corresponding to fd.
+func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry {
+ t.Helper()
+ f, ok := fd.Impl().(*fileDescription)
+ if !ok {
+ t.Fatalf("can't assert %T as a *fileDescription", fd)
+ }
+ return f.d
+}
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ t.Helper()
k, err := testutil.Boot()
if err != nil {
t.Fatalf("testutil.Boot: %v", err)
@@ -69,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
- LowerName: "tmpfs",
- Alg: hashAlg,
- AllowRuntimeEnable: true,
- NoCrashOnVerificationFailure: true,
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ Alg: hashAlg,
+ AllowRuntimeEnable: true,
+ Action: ErrorOnViolation,
},
},
})
@@ -92,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
t.Fatalf("testutil.CreateTask: %v", err)
}
- t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
@@ -100,21 +122,97 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
return vfsObj, root, task, nil
}
-// newFileFD creates a new file in the verity mount, and returns the FD. The FD
-// points to a file that has random data generated.
-func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
- creds := auth.CredentialsFromContext(ctx)
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+// openVerityAt opens a verity file.
+//
+// TODO(chongc): release reference from opening the file when done.
+func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: vd,
+ Start: vd,
+ Path: fspath.Parse(path),
+ }, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: mode,
+ })
+}
- // Create the file in the underlying file system.
- lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(filePath),
+// openLowerAt opens the file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
+func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(path),
+ }, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: mode,
+ })
+}
+
+// openLowerMerkleAt opens the Merkle file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
+func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
}, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
- Mode: linux.ModeRegular | mode,
+ Flags: flags,
+ Mode: mode,
+ })
+}
+
+// unlinkLowerAt deletes the file in the underlying file system.
+func (d *dentry) unlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error {
+ return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(path),
})
+}
+
+// unlinkLowerMerkleAt deletes the Merkle file in the underlying file system.
+func (d *dentry) unlinkLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error {
+ return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + path),
+ })
+}
+
+// renameLowerAt renames file name to newName in the underlying file system.
+func (d *dentry) renameLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error {
+ return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(name),
+ }, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(newName),
+ }, &vfs.RenameOptions{})
+}
+
+// renameLowerMerkleAt renames Merkle file name to newName in the underlying
+// file system.
+func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error {
+ return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + name),
+ }, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + newName),
+ }, &vfs.RenameOptions{})
+}
+
+// newFileFD creates a new file in the verity mount, and returns the FD. The FD
+// points to a file that has random data generated.
+func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ // Create the file in the underlying file system.
+ lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
if err != nil {
return nil, 0, err
}
@@ -137,20 +235,24 @@ func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.Virt
lowerFD.DecRef(ctx)
// Now open the verity file descriptor.
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filePath),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular | mode,
- })
+ fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
return fd, dataSize, err
}
-// corruptRandomBit randomly flips a bit in the file represented by fd.
-func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
- // Flip a random bit in the underlying file.
+// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD.
+func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) {
+ // Create the file in the underlying file system.
+ _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+ if err != nil {
+ return nil, err
+ }
+ // Now open the verity file descriptor.
+ fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
+ return fd, err
+}
+
+// flipRandomBit randomly flips a bit in the file represented by fd.
+func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
randomPos := int64(rand.Intn(size))
byteToModify := make([]byte, 1)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
@@ -163,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
-var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ t.Helper()
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("enable verity: %v", err)
+ }
+}
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
@@ -175,30 +284,18 @@ func TestOpen(t *testing.T) {
}
filename := "verity-test-file"
- if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Ensure that the corresponding Merkle tree file is created.
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
+ if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
}
// Ensure the root merkle tree file is created.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + rootMerkleFilename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
+ if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
}
}
@@ -214,17 +311,13 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
@@ -248,17 +341,13 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
@@ -272,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
}
+// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity
+// file succeeds after enabling verity for it.
+func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-empty-file"
+ fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newEmptyFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ enableVerity(ctx, t, fd)
+
+ var buf []byte
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != 0 {
+ t.Errorf("fd.Read got read length %d, expected 0", n)
+ }
+ }
+}
+
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
@@ -282,27 +401,16 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Ensure reopening the verity enabled file succeeds.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != nil {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("reopen enabled file failed: %v", err)
}
}
@@ -317,43 +425,24 @@ func TestOpenNonexistentFile(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
+ parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Ensure open an unexpected file in the parent directory fails with
// ENOENT rather than verification failure.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename + "abc"),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.ENOENT {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT {
t.Errorf("OpenAt unexpected error: %v", err)
}
}
@@ -368,33 +457,22 @@ func TestPReadModifiedFileFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -415,33 +493,22 @@ func TestReadModifiedFileFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -462,27 +529,16 @@ func TestModifiedMerkleFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerMerkleFD that's read/writable.
- lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
-
- lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerMerkleVD,
- Start: lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -493,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) {
t.Errorf("lowerMerkleFD.Stat: %v", err)
}
- if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from a file with modified Merkle tree fails.
buf := make([]byte, size)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
t.Fatalf("fd.PRead succeeded with modified Merkle file")
}
}
@@ -517,42 +572,23 @@ func TestModifiedParentMerkleFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
+ parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
-
- parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: parentLowerMerkleVD,
- Start: parentLowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -572,21 +608,14 @@ func TestModifiedParentMerkleFails(t *testing.T) {
if err != nil {
t.Fatalf("Failed convert size to int: %v", err)
}
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
parentLowerMerkleFD.DecRef(ctx)
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err == nil {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err == nil {
t.Errorf("OpenAt file with modified parent Merkle succeeded")
}
}
@@ -602,18 +631,13 @@ func TestUnmodifiedStatSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
+ // Enable verity on the file and confirm that stat succeeds.
+ enableVerity(ctx, t, fd)
if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
t.Errorf("fd.Stat: %v", err)
}
@@ -630,17 +654,13 @@ func TestModifiedStatFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
lowerFD := fd.Impl().(*fileDescription).lowerFD
// Change the stat of the underlying file, and check that stat fails.
@@ -663,73 +683,57 @@ func TestModifiedStatFails(t *testing.T) {
// and/or the corresponding Merkle tree file fails with the verity error.
func TestOpenDeletedFileFails(t *testing.T) {
testCases := []struct {
+ name string
// The original file is removed if changeFile is true.
changeFile bool
// The Merkle tree file is removed if changeMerkleFile is true.
changeMerkleFile bool
}{
{
+ name: "FileOnly",
changeFile: true,
changeMerkleFile: false,
},
{
+ name: "MerkleOnly",
changeFile: false,
changeMerkleFile: true,
},
{
+ name: "FileAndMerkle",
changeFile: true,
changeMerkleFile: true,
},
}
for _, tc := range testCases {
- t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) {
+ t.Run(tc.name, func(t *testing.T) {
vfsObj, root, ctx, err := newVerityRoot(t, SHA256)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
- rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
if tc.changeFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
@@ -740,82 +744,58 @@ func TestOpenDeletedFileFails(t *testing.T) {
// and/or the corresponding Merkle tree file fails with the verity error.
func TestOpenRenamedFileFails(t *testing.T) {
testCases := []struct {
+ name string
// The original file is renamed if changeFile is true.
changeFile bool
// The Merkle tree file is renamed if changeMerkleFile is true.
changeMerkleFile bool
}{
{
+ name: "FileOnly",
changeFile: true,
changeMerkleFile: false,
},
{
+ name: "MerkleOnly",
changeFile: false,
changeMerkleFile: true,
},
{
+ name: "FileAndMerkle",
changeFile: true,
changeMerkleFile: true,
},
}
for _, tc := range testCases {
- t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) {
+ t.Run(tc.name, func(t *testing.T) {
vfsObj, root, ctx, err := newVerityRoot(t, SHA256)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
- rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
newFilename := "renamed-test-file"
if tc.changeFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(newFilename),
- }, &vfs.RenameOptions{}); err != nil {
+ if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("RenameAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + newFilename),
- }, &vfs.RenameOptions{}); err != nil {
+ if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
diff --git a/pkg/sentry/fsmetric/BUILD b/pkg/sentry/fsmetric/BUILD
new file mode 100644
index 000000000..4e86fbdd8
--- /dev/null
+++ b/pkg/sentry/fsmetric/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsmetric",
+ srcs = ["fsmetric.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/metric"],
+)
diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go
new file mode 100644
index 000000000..7e535b527
--- /dev/null
+++ b/pkg/sentry/fsmetric/fsmetric.go
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsmetric defines filesystem metrics that are used by both VFS1 and
+// VFS2.
+//
+// TODO(gvisor.dev/issue/1624): Once VFS1 is deleted, inline these metrics into
+// VFS2.
+package fsmetric
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/metric"
+)
+
+// RecordWaitTime enables the ReadWait, GoferReadWait9P, GoferReadWaitHost, and
+// TmpfsReadWait metrics. Enabling this comes at a CPU cost due to performing
+// three clock reads per read call.
+//
+// Note that this is only performed in the direct read path, and may not be
+// consistently applied for other forms of reads, such as splice.
+var RecordWaitTime = false
+
+// Metrics that apply to all filesystems.
+var (
+ Opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
+ Reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
+ ReadWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
+)
+
+// Metrics that only apply to fs/gofer and fsimpl/gofer.
+var (
+ GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.")
+ GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.")
+ GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.")
+ GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
+ GoferReadWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
+ GoferReadsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
+ GoferReadWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
+)
+
+// Metrics that only apply to fs/tmpfs and fsimpl/tmpfs.
+var (
+ TmpfsOpensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
+ TmpfsOpensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
+ TmpfsReads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
+ TmpfsReadWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
+)
+
+// StartReadWait indicates the beginning of a file read.
+func StartReadWait() time.Time {
+ if !RecordWaitTime {
+ return time.Time{}
+ }
+ return time.Now()
+}
+
+// FinishReadWait indicates the end of a file read whose time is accounted by
+// m. start must be the value returned by the corresponding call to
+// StartReadWait.
+//
+// FinishReadWait is marked nosplit for performance since it's often called
+// from defer statements, which prevents it from being inlined
+// (https://github.com/golang/go/issues/38471).
+//go:nosplit
+func FinishReadWait(m *metric.Uint64Metric, start time.Time) {
+ if !RecordWaitTime {
+ return
+ }
+ m.IncrementBy(uint64(time.Since(start).Nanoseconds()))
+}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 15519f0df..61aeca044 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
//
// Callback is called when one of the files we're polling becomes ready. It
// moves said file to the readyList if it's currently in the waiting list.
-func (p *pollEntry) Callback(*waiter.Entry) {
+func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) {
e := p.epoll
e.listsMu.Lock()
@@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
f.EventRegister(&entry.waiter, entry.mask)
// Check if the file happens to already be in a ready state.
- ready := f.Readiness(entry.mask) & entry.mask
- if ready != 0 {
- entry.Callback(&entry.waiter)
+ if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 {
+ entry.Callback(&entry.waiter, ready)
}
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 2b3955598..f855f038b 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,11 +8,13 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 153d2cd9b..b66d61c6f 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,22 +17,45 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new fs.FileAsync.
-func New() fs.FileAsync {
- return &FileAsync{}
+// Table to convert waiter event masks into si_band siginfo codes.
+// Taken from fs/fcntl.c:band_table.
+var bandTable = map[waiter.EventMask]int64{
+ // POLL_IN
+ waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM,
+ // POLL_OUT
+ waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND,
+ // POLL_ERR
+ waiter.EventErr: linux.EPOLLERR,
+ // POLL_PRI
+ waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND,
+ // POLL_HUP
+ waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR,
}
-// NewVFS2 creates a new vfs.FileAsync.
-func NewVFS2() vfs.FileAsync {
- return &FileAsync{}
+// New returns a function that creates a new fs.FileAsync with the given file
+// descriptor.
+func New(fd int) func() fs.FileAsync {
+ return func() fs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
+}
+
+// NewVFS2 returns a function that creates a new vfs.FileAsync with the given
+// file descriptor.
+func NewVFS2(fd int) func() vfs.FileAsync {
+ return func() vfs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
}
// FileAsync sends signals when the registered file is ready for IO.
@@ -42,6 +65,12 @@ type FileAsync struct {
// e is immutable after first use (which is protected by mu below).
e waiter.Entry
+ // fd is the file descriptor to notify about.
+ // It is immutable, set at allocation time. This matches Linux semantics in
+ // fs/fcntl.c:fasync_helper.
+ // The fd value is passed to the signal recipient in siginfo.si_fd.
+ fd int
+
// regMu protects registeration and unregistration actions on e.
//
// regMu must be held while registration decisions are being made
@@ -56,6 +85,10 @@ type FileAsync struct {
mu sync.Mutex `state:"nosave"`
requester *auth.Credentials
registered bool
+ // signal is the signal to deliver upon I/O being available.
+ // The default value ("zero signal") means the default SIGIO signal will be
+ // delivered.
+ signal linux.Signal
// Only one of the following is allowed to be non-nil.
recipientPG *kernel.ProcessGroup
@@ -64,10 +97,10 @@ type FileAsync struct {
}
// Callback sends a signal.
-func (a *FileAsync) Callback(e *waiter.Entry) {
+func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
a.mu.Lock()
+ defer a.mu.Unlock()
if !a.registered {
- a.mu.Unlock()
return
}
t := a.recipientT
@@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) {
}
if t == nil {
// No recipient has been registered.
- a.mu.Unlock()
return
}
c := t.Credentials()
// Logic from sigio_perm in fs/fcntl.c.
- if a.requester.EffectiveKUID == 0 ||
+ permCheck := (a.requester.EffectiveKUID == 0 ||
a.requester.EffectiveKUID == c.SavedKUID ||
a.requester.EffectiveKUID == c.RealKUID ||
a.requester.RealKUID == c.SavedKUID ||
- a.requester.RealKUID == c.RealKUID {
- t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ a.requester.RealKUID == c.RealKUID)
+ if !permCheck {
+ return
}
- a.mu.Unlock()
+ signalInfo := &arch.SignalInfo{
+ Signo: int32(linux.SIGIO),
+ Code: arch.SignalInfoKernel,
+ }
+ if a.signal != 0 {
+ signalInfo.Signo = int32(a.signal)
+ signalInfo.SetFD(uint32(a.fd))
+ var band int64
+ for m, bandCode := range bandTable {
+ if m&mask != 0 {
+ band |= bandCode
+ }
+ }
+ signalInfo.SetBand(band)
+ }
+ t.SendSignal(signalInfo)
}
// Register sets the file which will be monitored for IO events.
@@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() {
a.recipientTG = nil
a.recipientPG = nil
}
+
+// Signal returns which signal will be sent to the signal recipient.
+// A value of zero means the signal to deliver wasn't customized, which means
+// the default signal (SIGIO) will be delivered.
+func (a *FileAsync) Signal() linux.Signal {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.signal
+}
+
+// SetSignal overrides which signal to send when I/O is available.
+// The default behavior can be reset by specifying signal zero, which means
+// to send SIGIO.
+func (a *FileAsync) SetSignal(signal linux.Signal) error {
+ if signal != 0 && !signal.IsValid() {
+ return syserror.EINVAL
+ }
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.signal = signal
+ return nil
+}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 470d8bf83..f17f9c59c 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
panic("VFS1 and VFS2 files set")
}
- slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
// Grow the table as required.
- if last := int32(len(slice)); fd >= last {
+ if last := int32(len(*slicePtr)); fd >= last {
end := fd + 1
if end < 2*last {
end = 2 * last
}
- slice = append(slice, make([]unsafe.Pointer, end-last)...)
- atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+ newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...)
+ slicePtr = &newSlice
+ atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr))
}
+ slice := *slicePtr
+
var desc *descriptor
if file != nil || fileVFS2 != nil {
desc = &descriptor{
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 2cdcdfc1f..b8627a54f 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -214,9 +214,11 @@ type Kernel struct {
// netlinkPorts manages allocation of netlink socket port IDs.
netlinkPorts *port.Manager
- // saveErr is the error causing the sandbox to exit during save, if
- // any. It is protected by extMu.
- saveErr error `state:"nosave"`
+ // saveStatus is nil if the sandbox has not been saved, errSaved or
+ // errAutoSaved if it has been saved successfully, or the error causing the
+ // sandbox to exit during save.
+ // It is protected by extMu.
+ saveStatus error `state:"nosave"`
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
@@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager {
return k.netlinkPorts
}
-// SaveError returns the sandbox error that caused the kernel to exit during
-// save.
-func (k *Kernel) SaveError() error {
+var (
+ errSaved = errors.New("sandbox has been successfully saved")
+ errAutoSaved = errors.New("sandbox has been successfully auto-saved")
+)
+
+// SaveStatus returns the sandbox save status. If it was saved successfully,
+// autosaved indicates whether save was triggered by autosave. If it was not
+// saved successfully, err indicates the sandbox error that caused the kernel to
+// exit during save.
+func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ switch k.saveStatus {
+ case nil:
+ return false, false, nil
+ case errSaved:
+ return true, false, nil
+ case errAutoSaved:
+ return true, true, nil
+ default:
+ return false, false, k.saveStatus
+ }
+}
+
+// SetSaveSuccess sets the flag indicating that save completed successfully, if
+// no status was already set.
+func (k *Kernel) SetSaveSuccess(autosave bool) {
k.extMu.Lock()
defer k.extMu.Unlock()
- return k.saveErr
+ if k.saveStatus == nil {
+ if autosave {
+ k.saveStatus = errAutoSaved
+ } else {
+ k.saveStatus = errSaved
+ }
+ }
}
// SetSaveError sets the sandbox error that caused the kernel to exit during
@@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error {
func (k *Kernel) SetSaveError(err error) {
k.extMu.Lock()
defer k.extMu.Unlock()
- if k.saveErr == nil {
- k.saveErr = err
+ if k.saveStatus == nil {
+ k.saveStatus = err
}
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1abfe2201..cef58a590 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) {
Signo: int32(linux.SIGTRAP),
Code: code,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index b99c0bffa..db01e4a97 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -29,17 +29,17 @@ import (
)
const (
- valueMax = 32767 // SEMVMX
+ // Maximum semaphore value.
+ valueMax = linux.SEMVMX
- // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
- semaphoresMax = 32000
+ // Maximum number of semaphore sets.
+ setsMax = linux.SEMMNI
- // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
- setsMax = 32000
+ // Maximum number of semaphroes in a semaphore set.
+ semsMax = linux.SEMMSL
- // semaphoresTotalMax is "system-wide limit on the number of semaphores"
- // (SEMMNS = SEMMNI*SEMMSL).
- semaphoresTotalMax = 1024000000
+ // Maximum number of semaphores in all semaphroe sets.
+ semsTotalMax = linux.SEMMNS
)
// Registry maintains a set of semaphores that can be found by key or ID.
@@ -52,6 +52,9 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
semaphores map[int32]*Set
lastIDUsed int32
+ // indexes maintains a mapping between a set's index in virtual array and
+ // its identifier.
+ indexes map[int32]int32
}
// Set represents a set of semaphores that can be operated atomically.
@@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
userNS: userNS,
semaphores: make(map[int32]*Set),
+ indexes: make(map[int32]int32),
}
}
@@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// be found. If exclusive is true, it fails if a set with the same key already
// exists.
func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
- if nsems < 0 || nsems > semaphoresMax {
+ if nsems < 0 || nsems > semsMax {
return nil, syserror.EINVAL
}
@@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Apply system limits.
+ //
+ // Map semaphores and map indexes in a registry are of the same size,
+ // check map semaphores only here for the system limit.
if len(r.semaphores) >= setsMax {
return nil, syserror.EINVAL
}
- if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ if r.totalSems() > int(semsTotalMax-nsems) {
return nil, syserror.EINVAL
}
@@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
return r.newSet(ctx, key, owner, owner, perms, nsems)
}
+// IPCInfo returns information about system-wide semaphore limits and parameters.
+func (r *Registry) IPCInfo() *linux.SemInfo {
+ return &linux.SemInfo{
+ SemMap: linux.SEMMAP,
+ SemMni: linux.SEMMNI,
+ SemMns: linux.SEMMNS,
+ SemMnu: linux.SEMMNU,
+ SemMsl: linux.SEMMSL,
+ SemOpm: linux.SEMOPM,
+ SemUme: linux.SEMUME,
+ SemUsz: linux.SEMUSZ,
+ SemVmx: linux.SEMVMX,
+ SemAem: linux.SEMAEM,
+ }
+}
+
+// SemInfo returns a seminfo structure containing the same information as
+// for IPC_INFO, except that SemUsz field returns the number of existing
+// semaphore sets, and SemAem field returns the number of existing semaphores.
+func (r *Registry) SemInfo() *linux.SemInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ info := r.IPCInfo()
+ info.SemUsz = uint32(len(r.semaphores))
+ info.SemAem = uint32(r.totalSems())
+
+ return info
+}
+
+// HighestIndex returns the index of the highest used entry in
+// the kernel's array.
+func (r *Registry) HighestIndex() int32 {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // By default, highest used index is 0 even though
+ // there is no semaphroe set.
+ var highestIndex int32
+ for index := range r.indexes {
+ if index > highestIndex {
+ highestIndex = index
+ }
+ }
+ return highestIndex
+}
+
// RemoveID removes set with give 'id' from the registry and marks the set as
// dead. All waiters will be awakened and fail.
func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
@@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
if set == nil {
return syserror.EINVAL
}
+ index, found := r.findIndexByID(id)
+ if !found {
+ // Inconsistent state.
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id))
+ }
set.mu.Lock()
defer set.mu.Unlock()
@@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
}
delete(r.semaphores, set.ID)
+ delete(r.indexes, index)
set.destroy()
return nil
}
@@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File
continue
}
if r.semaphores[id] == nil {
+ index, found := r.findFirstAvailableIndex()
+ if !found {
+ panic("unable to find an available index")
+ }
+ r.indexes[index] = id
r.lastIDUsed = id
r.semaphores[id] = set
set.ID = id
@@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set {
return r.semaphores[id]
}
+// FindByIndex looks up a set given an index.
+func (r *Registry) FindByIndex(index int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ id, present := r.indexes[index]
+ if !present {
+ return nil
+ }
+ return r.semaphores[id]
+}
+
func (r *Registry) findByKey(key int32) *Set {
for _, v := range r.semaphores {
if v.key == key {
@@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set {
return nil
}
+func (r *Registry) findIndexByID(id int32) (int32, bool) {
+ for k, v := range r.indexes {
+ if v == id {
+ return k, true
+ }
+ }
+ return 0, false
+}
+
+func (r *Registry) findFirstAvailableIndex() (int32, bool) {
+ for index := int32(0); index < setsMax; index++ {
+ if _, present := r.indexes[index]; !present {
+ return index, true
+ }
+ }
+ return 0, false
+}
+
func (r *Registry) totalSems() int {
totalSems := 0
for _, v := range r.semaphores {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 80a592c8f..073e14507 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -6,6 +6,9 @@ package(licenses = ["notice"])
go_template_instance(
name = "shm_refs",
out = "shm_refs.go",
+ consts = {
+ "enableLogging": "true",
+ },
package = "shm",
prefix = "Shm",
template = "//pkg/refsvfs2:refs_template",
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index e8cce37d0..2488ae7d5 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 78f718cfe..884966120 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index a83ce219c..3fee7aa68 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -75,6 +75,12 @@ func (s *syslog) Log() []byte {
"Checking naughty and nice process list...", // Check it up to twice.
"Granting licence to kill(2)...", // British spelling for British movie.
"Letting the watchdogs out...",
+ "Conjuring /dev/null black hole...",
+ "Adversarially training Redcode AI...",
+ "Singleplexing /dev/ptmx...",
+ "Recruiting cron-ies...",
+ "Verifying that no non-zero bytes made their way into /dev/zero...",
+ "Accelerating teletypewriter to 9600 baud...",
}
selectMessage := func() string {
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c5137c282..16986244c 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -368,8 +368,8 @@ func (t *Task) exitChildren() {
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- siginfo.SetPid(int32(c.tg.pidns.tids[t]))
- siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ siginfo.SetPID(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
c.tg.signalHandlers.mu.Lock()
c.sendSignalLocked(siginfo, true /* group */)
c.tg.signalHandlers.mu.Unlock()
@@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si
info := &arch.SignalInfo{
Signo: int32(sig),
}
- info.SetPid(int32(receiver.tg.pidns.tids[t]))
- info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.tids[t]))
+ info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
info.Code = arch.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 42dd3e278..75af3af79 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
Signo: int32(linux.SIGCHLD),
Code: code,
}
- sigchld.SetPid(int32(t.tg.pidns.tids[target]))
- sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetPID(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
sigchld.SetStatus(status)
// TODO(b/72102453): Set utime, stime.
t.sendSignalLocked(sigchld, true /* group */)
@@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState {
Signo: int32(sig),
Code: t.ptraceCode,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
} else {
t.ptraceCode = int32(sig)
t.ptraceSiginfo = nil
@@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if parent == nil {
// Tracer has detached and t was created by Kernel.CreateProcess().
// Pretend the parent is in an ancestor PID + user namespace.
- info.SetPid(0)
- info.SetUid(int32(auth.OverflowUID))
+ info.SetPID(0)
+ info.SetUID(int32(auth.OverflowUID))
} else {
- info.SetPid(int32(t.tg.pidns.tids[parent]))
- info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(t.tg.pidns.tids[parent]))
+ info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
}
}
t.tg.signalHandlers.mu.Lock()
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 7fd77925f..49e21026e 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
// Translations must be contiguous and in increasing order of
// Translation.Source.
if i > 0 && ts[i-1].Source.End != t.Source.Start {
- return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t)
}
// At least part of each Translation must be required.
if t.Source.Intersect(required).Length() == 0 {
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 4c8cd38ed..5ab2ef79f 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -36,12 +36,12 @@ type aioManager struct {
contexts map[uint64]*AIOContext
}
-func (a *aioManager) destroy() {
- a.mu.Lock()
- defer a.mu.Unlock()
+func (mm *MemoryManager) destroyAIOManager(ctx context.Context) {
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
- for _, ctx := range a.contexts {
- ctx.destroy()
+ for id := range mm.aioManager.contexts {
+ mm.destroyAIOContextLocked(ctx, id)
}
}
@@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
// be drained.
//
// Nil is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
- a.mu.Lock()
- defer a.mu.Unlock()
- ctx, ok := a.contexts[id]
+//
+// Precondition: mm.aioManager.mu is locked.
+func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext {
+ aioCtx, ok := mm.aioManager.contexts[id]
if !ok {
return nil
}
- delete(a.contexts, id)
- ctx.destroy()
- return ctx
+
+ // Only unmaps after it assured that the address is a valid aio context to
+ // prevent random memory from been unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ delete(mm.aioManager.contexts, id)
+ aioCtx.destroy()
+ return aioCtx
}
// lookupAIOContext looks up the given context.
@@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() {
}
}
-// Prepare reserves space for a new request, returning true if available.
-// Returns false if the context is busy.
-func (ctx *AIOContext) Prepare() bool {
+// Prepare reserves space for a new request, returning nil if available.
+// Returns EAGAIN if the context is busy and EINVAL if the context is dead.
+func (ctx *AIOContext) Prepare() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
+ if ctx.dead {
+ // Context died after the caller looked it up.
+ return syserror.EINVAL
+ }
if ctx.outstanding >= ctx.maxOutstanding {
- return false
+ // Context is busy.
+ return syserror.EAGAIN
}
ctx.outstanding++
- return true
+ return nil
}
// PopRequest pops a completed request if available, this function does not do
@@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
// DestroyAIOContext destroys an asynchronous I/O context. It returns the
// destroyed context. nil if the context does not exist.
func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
- if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ if !mm.isValidAddr(ctx, id) {
return nil
}
- // Only unmaps after it assured that the address is a valid aio context to
- // prevent random memory from been unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
-
- return mm.aioManager.destroyAIOContext(id)
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
+ return mm.destroyAIOContextLocked(ctx, id)
}
// LookupAIOContext looks up the given context. It returns false if the context
@@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC
return nil, false
}
- // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
- // from id).
- var buf [4]byte
- _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
- if err != nil {
+ // Protect against 'id' that is inaccessible.
+ if !mm.isValidAddr(ctx, id) {
return nil, false
}
return aioCtx, true
}
+
+// isValidAddr determines if the address `id` is valid. (Linux also reads 4
+// bytes from id).
+func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool {
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ return err == nil
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index 3dabac1af..e8931922f 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -15,6 +15,6 @@
package mm
// afterLoad is invoked by stateify.
-func (a *AIOContext) afterLoad() {
- a.requestReady = make(chan struct{}, 1)
+func (ctx *AIOContext) afterLoad() {
+ ctx.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 09dbc06a4..120707429 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
}
- mm.aioManager.destroy()
+ mm.destroyAIOManager(ctx)
mm.metadataMu.Lock()
exe := mm.executable
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index acac3d357..bc53bd41e 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) {
t.Errorf("CopyOut got %d want 1", n)
}
}
+
+// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be
+// prepared after destruction.
+func TestAIOPrepareAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ aioCtx, ok := mm.LookupAIOContext(ctx, id)
+ if !ok {
+ t.Fatalf("AIOContext not found")
+ }
+ mm.DestroyAIOContext(ctx, id)
+
+ // Prepare should fail because aioCtx should be destroyed.
+ if err := aioCtx.Prepare(); err != syserror.EINVAL {
+ t.Errorf("aioCtx.Prepare got err %v want nil", err)
+ } else if err == nil {
+ aioCtx.CancelPendingRequest()
+ }
+}
+
+// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be
+// looked up after memory manager is destroyed.
+func TestAIOLookupAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ mm.DecUsers(ctx)
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ mm.DecUsers(ctx) // This destroys the AIOContext manager.
+
+ if _, ok := mm.LookupAIOContext(ctx, id); ok {
+ t.Errorf("AIOContext found even after AIOContext manager is destroyed")
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 7c297fb9e..d99be7f46 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
}
if f.opts.ManualZeroing {
- if err := f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
- }
- }); err != nil {
+ if err := f.manuallyZero(fr); err != nil {
return memmap.FileRange{}, err
}
}
@@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
panic(fmt.Sprintf("invalid range: %v", fr))
}
+ if f.opts.ManualZeroing {
+ // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in
+ // effect.
+ if err := f.manuallyZero(fr); err != nil {
+ return err
+ }
+ } else {
+ if err := f.decommitFile(fr); err != nil {
+ return err
+ }
+ }
+
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error {
+ return f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+}
+
+func (f *MemoryFile) decommitFile(fr memmap.FileRange) error {
// "After a successful call, subsequent reads from this range will
// return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
// FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
- err := syscall.Fallocate(
+ return syscall.Fallocate(
int(f.file.Fd()),
_FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
int64(fr.Start),
int64(fr.Length()))
- if err != nil {
- return err
- }
- f.markDecommitted(fr)
- return nil
}
func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
@@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- if err := f.Decommit(fr); err != nil {
- log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
- // Zero the pages manually. This won't reduce memory usage, but at
- // least ensures that the pages will be zero when reallocated.
- f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
+ // If ManualZeroing is in effect, pages will be zeroed on allocation
+ // and may not be freed by decommitFile, so calling decommitFile is
+ // unnecessary.
+ if !f.opts.ManualZeroing {
+ if err := f.decommitFile(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ if err := f.manuallyZero(fr); err != nil {
+ panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err))
}
- })
- // Pretend the pages were decommitted even though they weren't,
- // since the memory accounting implementation has no idea how to
- // deal with this.
- f.markDecommitted(fr)
+ }
}
+ f.markDecommitted(fr)
f.markReclaimed(fr)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index acad4c793..f8ccb7430 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) {
}
}
+// bluepillHandleEnosys is reponsible for handling enosys error.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ throw("run failed: ENOSYS")
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
@@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool {
}
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ c.die(bluepillArchContext(context), "unknown")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 965ad66b5..1f09813ba 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -42,6 +42,13 @@ var (
sErrEsr: _ESR_ELx_SERR_NMI,
},
}
+
+ // vcpuExtDabt is the event of ext_dabt.
+ vcpuExtDabt = kvmVcpuEvents{
+ exception: exception{
+ extDabtPending: 1,
+ },
+ }
)
// getTLS returns the value of TPIDR_EL0 register.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 9433d4da5..4d912769a 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) {
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
- throw("sErr injection failed")
+ throw("bounce sErr injection failed")
}
}
@@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillSigBus(c *vCPU) {
+ // Host must support ARM64_HAS_RAS_EXTN.
if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
- throw("sErr injection failed")
+ if errno == syscall.EINVAL {
+ throw("No ARM64_HAS_RAS_EXTN feature in host.")
+ }
+ throw("nmi sErr injection failed")
}
}
+// bluepillExtDabt is reponsible for injecting external data abort.
+//
+//go:nosplit
+func bluepillExtDabt(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 {
+ throw("ext_dabt injection failed")
+ }
+}
+
+// bluepillHandleEnosys is reponsible for handling enosys error.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ bluepillExtDabt(c)
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ switch c.runData.exitReason {
+ case _KVM_EXIT_ARM_NISV:
+ bluepillExtDabt(c)
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ }
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 75085ac6a..8c5369377 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) {
// mode and have interrupts disabled.
bluepillSigBus(c)
continue // Rerun vCPU.
+ case syscall.ENOSYS:
+ bluepillHandleEnosys(c)
+ continue
default:
throw("run failed")
}
@@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "entry failed")
return
default:
- c.die(bluepillArchContext(context), "unknown")
+ bluepillArchHandleExit(c, context)
return
}
}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 0b06a923a..9db1db4e9 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -47,10 +47,11 @@ type userRegs struct {
}
type exception struct {
- sErrPending uint8
- sErrHasEsr uint8
- pad [6]uint8
- sErrEsr uint64
+ sErrPending uint8
+ sErrHasEsr uint8
+ extDabtPending uint8
+ pad [5]uint8
+ sErrEsr uint64
}
type kvmVcpuEvents struct {
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 6abaa21c4..2492d57be 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -56,6 +56,7 @@ const (
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
_KVM_EXIT_SYSTEM_EVENT = 0x18
+ _KVM_EXIT_ARM_NISV = 0x1c
)
// KVM capability options.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index fd92c3873..3f5be276b 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -263,13 +263,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return usermem.NoAccess, platform.ErrContextInterrupt
case ring0.El0SyncUndef:
return c.fault(int32(syscall.SIGILL), info)
- case ring0.El1SyncUndef:
- *info = arch.SignalInfo{
- Signo: int32(syscall.SIGILL),
- Code: 1, // ILL_ILLOPC (illegal opcode).
- }
- info.SetAddr(switchOpts.Registers.Pc) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
default:
panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
}
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index f56aa3b79..571bfcc2e 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -18,8 +18,8 @@
//
// In a nutshell, it works as follows:
//
-// The creation of a new address space creates a new child processes with a
-// single thread which is traced by a single goroutine.
+// The creation of a new address space creates a new child process with a single
+// thread which is traced by a single goroutine.
//
// A context is just a collection of temporary variables. Calling Switch on a
// context does the following:
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 812ab80ef..aacd7ce70 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// facilitate vsyscall emulation. See patchSignalInfo.
patchSignalInfo(regs, &c.signalInfo)
return false
- } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) {
// The signal was generated by this process. That means
// that it was an interrupt or something else that we
// should bail for. Note that we ignore signals
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 679b287c3..2852b7387 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "arch_genrule", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
@@ -39,19 +39,19 @@ go_template_instance(
template = ":defs_arm64",
)
-genrule(
+arch_genrule(
name = "entry_impl_amd64",
srcs = ["entry_amd64.s"],
outs = ["entry_impl_amd64.s"],
- cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
-genrule(
+arch_genrule(
name = "entry_impl_arm64",
srcs = ["entry_arm64.s"],
outs = ["entry_impl_arm64.s"],
- cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
@@ -72,7 +72,6 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
- "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 327d48465..c51df2811 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -90,6 +90,7 @@ const (
El0SyncIa
El0SyncFpsimdAcc
El0SyncSveAcc
+ El0SyncFpsimdExc
El0SyncSys
El0SyncSpPc
El0SyncUndef
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index f489ad352..cf0bf3528 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -132,40 +132,6 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-// NOP-s
-#define nop31Instructions() \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f;
-
#define ESR_ELx_EC_UNKNOWN (0x00)
#define ESR_ELx_EC_WFx (0x01)
/* Unallocated EC: 0x02 */
@@ -324,6 +290,18 @@
MOVD CPU_TTBR0_KVM(from), RSV_REG; \
MSR RSV_REG, TTBR0_EL1;
+TEXT ·EnableVFP(SB),NOSPLIT,$0
+ MOVD $FPEN_ENABLE, R0
+ WORD $0xd5181040 //MSR R0, CPACR_EL1
+ ISB $15
+ RET
+
+TEXT ·DisableVFP(SB),NOSPLIT,$0
+ MOVD $0, R0
+ WORD $0xd5181040 //MSR R0, CPACR_EL1
+ ISB $15
+ RET
+
#define VFP_ENABLE \
MOVD $FPEN_ENABLE, R0; \
WORD $0xd5181040; \ //MSR R0, CPACR_EL1
@@ -370,12 +348,12 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
-// EXCEPTION_WITH_ERROR is a common exception handler function.
-#define EXCEPTION_WITH_ERROR(user, vector) \
+// EXCEPTION_EL0 is a common el0 exception handler function.
+#define EXCEPTION_EL0(vector) \
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
WORD $0xd538601a; \ //MRS FAR_EL1, R26
MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
- MOVD $user, R3; \
+ MOVD $1, R3; \
MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
MOVD $vector, R3; \
MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
@@ -383,6 +361,12 @@
MOVD R3, CPU_ERROR_CODE(RSV_REG); \
B ·kernelExitToEl1(SB);
+// EXCEPTION_EL1 is a common el1 exception handler function.
+#define EXCEPTION_EL1(vector) \
+ MOVD $vector, R3; \
+ MOVD R3, 8(RSP); \
+ B ·HaltEl1ExceptionAndResume(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -430,6 +414,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
CALL ·kernelSyscall(SB) // Call the trampoline.
B ·kernelExitToEl1(SB) // Resume.
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume.
+TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ MOVD vector+0(FP), R3
+ MOVD R3, 16(RSP) // Second argument (vector).
+ CALL ·kernelException(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
// Shutdown stops the guest.
TEXT ·Shutdown(SB),NOSPLIT,$0
// PSCI EVENT.
@@ -592,39 +586,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
B el1_invalid
el1_da:
+ EXCEPTION_EL1(El1SyncDa)
el1_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·HaltAndResume(SB)
-
+ EXCEPTION_EL1(El1SyncIa)
el1_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncSpPc)
el1_undef:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncUndef)
el1_svc:
- MOVD $0, CPU_ERROR_CODE(RSV_REG)
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
B ·HaltEl1SvcAndResume(SB)
-
el1_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncDbg)
el1_fpsimd_acc:
VFP_ENABLE
B ·kernelExitToEl1(SB) // Resume.
-
el1_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL1(El1SyncInv)
// El1_irq is the handler for El1_irq.
TEXT ·El1_irq(SB),NOSPLIT,$0
@@ -680,28 +657,21 @@ el0_svc:
el0_da:
el0_ia:
- EXCEPTION_WITH_ERROR(1, PageFault)
-
+ EXCEPTION_EL0(PageFault)
el0_fpsimd_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdAcc)
el0_sve_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSveAcc)
el0_fpsimd_exc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdExc)
el0_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSpPc)
el0_undef:
- EXCEPTION_WITH_ERROR(1, El0SyncUndef)
-
+ EXCEPTION_EL0(El0SyncUndef)
el0_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncDbg)
el0_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL0(El0SyncInv)
TEXT ·El0_irq(SB),NOSPLIT,$0
B ·Shutdown(SB)
@@ -760,79 +730,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
// Vectors implements exception vector table.
+// The start address of exception vector table should be 11-bits aligned.
+// For detail, please refer to arm developer document:
+// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table
+// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S
TEXT ·Vectors(SB),NOSPLIT,$0
+ PCALIGN $2048
B ·El1_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error_invalid(SB)
- nop31Instructions()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
- // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- WORD $0xd503201f //nop
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 9742308d8..a9703baf6 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,6 +24,9 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
+ # Use the libc malloc to avoid any extra dependencies. This is required to
+ # pass the sentry deps test.
+ system_malloc = True,
visibility = [
"//pkg/sentry/platform/kvm:__pkg__",
"//pkg/sentry/platform/ring0:__pkg__",
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index c1c808b96..c05284641 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -24,6 +24,10 @@ func HaltAndResume()
//go:nosplit
func HaltEl1SvcAndResume()
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume.
+//go:nosplit
+func HaltEl1ExceptionAndResume()
+
// init initializes architecture-specific state.
func (k *Kernel) init(maxCPUs int) {
}
@@ -49,6 +53,12 @@ func IsCanonical(addr uint64) bool {
return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
}
+// SwitchToUser performs an eret.
+//
+// The return value is the exception vector.
+//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
@@ -61,11 +71,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(PsrFlagsClear)
regs.Pstate |= UserFlagsSet
+ EnableVFP()
LoadFloatingPoint(switchOpts.FloatingPointState)
kernelExitToEl0()
SaveFloatingPoint(switchOpts.FloatingPointState)
+ DisableVFP()
vector = c.vecCode
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index bf1c655f4..a490bf3af 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -34,13 +34,13 @@ func FlushTlbAll()
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
-// FPCR returns the value of FPCR register.
+// GetFPCR returns the value of FPCR register.
func GetFPCR() (value uintptr)
// SetFPCR writes the FPCR value.
func SetFPCR(value uintptr)
-// FPSR returns the value of FPSR register.
+// GetFPSR returns the value of FPSR register.
func GetFPSR() (value uintptr)
// SetFPSR writes the FPSR value.
@@ -59,9 +59,13 @@ func LoadFloatingPoint(*byte)
// SaveFloatingPoint saves floating point state.
func SaveFloatingPoint(*byte)
+// EnableVFP enables fpsimd.
+func EnableVFP()
+
+// DisableVFP disables fpsimd.
+func DisableVFP()
+
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
-func Init() {
- rewriteVectors()
-}
+func Init() {}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 675a8bdb7..e39b32841 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -52,62 +52,47 @@ TEXT ·CPACREL1(SB),NOSPLIT,$0-8
RET
TEXT ·GetFPCR(SB),NOSPLIT,$0-8
- WORD $0xd53b4201 // MRS NZCV, R1
+ MOVD FPCR, R1
MOVD R1, ret+0(FP)
RET
TEXT ·GetFPSR(SB),NOSPLIT,$0-8
- WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD FPSR, R1
MOVD R1, ret+0(FP)
RET
TEXT ·SetFPCR(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R1
- WORD $0xd51b4201 // MSR R1, NZCV
+ MOVD R1, FPCR
RET
TEXT ·SetFPSR(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R1
- WORD $0xd51b4421 // MSR R1, FPSR
+ MOVD R1, FPSR
RET
TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R0
// Skip aarch64_ctx, fpsr, fpcr.
- FMOVD F0, 16*1(R0)
- FMOVD F1, 16*2(R0)
- FMOVD F2, 16*3(R0)
- FMOVD F3, 16*4(R0)
- FMOVD F4, 16*5(R0)
- FMOVD F5, 16*6(R0)
- FMOVD F6, 16*7(R0)
- FMOVD F7, 16*8(R0)
- FMOVD F8, 16*9(R0)
- FMOVD F9, 16*10(R0)
- FMOVD F10, 16*11(R0)
- FMOVD F11, 16*12(R0)
- FMOVD F12, 16*13(R0)
- FMOVD F13, 16*14(R0)
- FMOVD F14, 16*15(R0)
- FMOVD F15, 16*16(R0)
- FMOVD F16, 16*17(R0)
- FMOVD F17, 16*18(R0)
- FMOVD F18, 16*19(R0)
- FMOVD F19, 16*20(R0)
- FMOVD F20, 16*21(R0)
- FMOVD F21, 16*22(R0)
- FMOVD F22, 16*23(R0)
- FMOVD F23, 16*24(R0)
- FMOVD F24, 16*25(R0)
- FMOVD F25, 16*26(R0)
- FMOVD F26, 16*27(R0)
- FMOVD F27, 16*28(R0)
- FMOVD F28, 16*29(R0)
- FMOVD F29, 16*30(R0)
- FMOVD F30, 16*31(R0)
- FMOVD F31, 16*32(R0)
- ISB $15
+ ADD $16, R0, R0
+
+ WORD $0xad000400 // stp q0, q1, [x0]
+ WORD $0xad010c02 // stp q2, q3, [x0, #32]
+ WORD $0xad021404 // stp q4, q5, [x0, #64]
+ WORD $0xad031c06 // stp q6, q7, [x0, #96]
+ WORD $0xad042408 // stp q8, q9, [x0, #128]
+ WORD $0xad052c0a // stp q10, q11, [x0, #160]
+ WORD $0xad06340c // stp q12, q13, [x0, #192]
+ WORD $0xad073c0e // stp q14, q15, [x0, #224]
+ WORD $0xad084410 // stp q16, q17, [x0, #256]
+ WORD $0xad094c12 // stp q18, q19, [x0, #288]
+ WORD $0xad0a5414 // stp q20, q21, [x0, #320]
+ WORD $0xad0b5c16 // stp q22, q23, [x0, #352]
+ WORD $0xad0c6418 // stp q24, q25, [x0, #384]
+ WORD $0xad0d6c1a // stp q26, q27, [x0, #416]
+ WORD $0xad0e741c // stp q28, q29, [x0, #448]
+ WORD $0xad0f7c1e // stp q30, q31, [x0, #480]
RET
@@ -115,39 +100,24 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R0
// Skip aarch64_ctx, fpsr, fpcr.
- FMOVD 16*1(R0), F0
- FMOVD 16*2(R0), F1
- FMOVD 16*3(R0), F2
- FMOVD 16*4(R0), F3
- FMOVD 16*5(R0), F4
- FMOVD 16*6(R0), F5
- FMOVD 16*7(R0), F6
- FMOVD 16*8(R0), F7
- FMOVD 16*9(R0), F8
- FMOVD 16*10(R0), F9
- FMOVD 16*11(R0), F10
- FMOVD 16*12(R0), F11
- FMOVD 16*13(R0), F12
- FMOVD 16*14(R0), F13
- FMOVD 16*15(R0), F14
- FMOVD 16*16(R0), F15
- FMOVD 16*17(R0), F16
- FMOVD 16*18(R0), F17
- FMOVD 16*19(R0), F18
- FMOVD 16*20(R0), F19
- FMOVD 16*21(R0), F20
- FMOVD 16*22(R0), F21
- FMOVD 16*23(R0), F22
- FMOVD 16*24(R0), F23
- FMOVD 16*25(R0), F24
- FMOVD 16*26(R0), F25
- FMOVD 16*27(R0), F26
- FMOVD 16*28(R0), F27
- FMOVD 16*29(R0), F28
- FMOVD 16*30(R0), F29
- FMOVD 16*31(R0), F30
- FMOVD 16*32(R0), F31
- ISB $15
+ ADD $16, R0, R0
+
+ WORD $0xad400400 // ldp q0, q1, [x0]
+ WORD $0xad410c02 // ldp q2, q3, [x0, #32]
+ WORD $0xad421404 // ldp q4, q5, [x0, #64]
+ WORD $0xad431c06 // ldp q6, q7, [x0, #96]
+ WORD $0xad442408 // ldp q8, q9, [x0, #128]
+ WORD $0xad452c0a // ldp q10, q11, [x0, #160]
+ WORD $0xad46340c // ldp q12, q13, [x0, #192]
+ WORD $0xad473c0e // ldp q14, q15, [x0, #224]
+ WORD $0xad484410 // ldp q16, q17, [x0, #256]
+ WORD $0xad494c12 // ldp q18, q19, [x0, #288]
+ WORD $0xad4a5414 // ldp q20, q21, [x0, #320]
+ WORD $0xad4b5c16 // ldp q22, q23, [x0, #352]
+ WORD $0xad4c6418 // ldp q24, q25, [x0, #384]
+ WORD $0xad4d6c1a // ldp q26, q27, [x0, #416]
+ WORD $0xad4e741c // ldp q28, q29, [x0, #448]
+ WORD $0xad4f7c1e // ldp q30, q31, [x0, #480]
RET
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
deleted file mode 100644
index c05166fea..000000000
--- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build arm64
-
-package ring0
-
-import (
- "reflect"
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/safecopy"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-const (
- nopInstruction = 0xd503201f
- instSize = unsafe.Sizeof(uint32(0))
- vectorsRawLen = 0x800
-)
-
-func unsafeSlice(addr uintptr, length int) (slice []uint32) {
- hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- hdr.Data = addr
- hdr.Len = length / int(instSize)
- hdr.Cap = length / int(instSize)
- return slice
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func rewriteVectors() {
- vectorsBegin := reflect.ValueOf(Vectors).Pointer()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // And the size is 0x800.
- // Please see the documentation as reference:
- // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
- //
- // But, golang does not allow to set a function's address to a specific value.
- // So, for gvisor, I defined the size of exception-vector-table as 4K,
- // filled the 2nd 2K part with NOP-s.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- //
- // So, the prerequisite for this function to work correctly is:
- // vectorsSafeLen >= 0x1000
- // vectorsRawLen = 0x800
- vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
- if vectorsSafeLen < 2*vectorsRawLen {
- panic("Can't update vectors")
- }
-
- vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
- vectorsRawLen32 := vectorsRawLen / int(instSize)
-
- offset := vectorsBegin & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
-
- _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-
- offset = offset / instSize // By index, not bytes.
- // Move exception-vector-table into the specific address, should uses memmove here.
- for i := 1; i <= vectorsRawLen32; i++ {
- vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
- }
-
- // Adjust branch since instruction was moved forward.
- for i := 0; i < vectorsRawLen32; i++ {
- if vectorsSafeTable[int(offset)+i] != nopInstruction {
- vectorsSafeTable[int(offset)+i] -= uint32(offset)
- }
- }
-
- _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-}
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 53bc3353c..b5652deb9 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -70,6 +70,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc)
fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index bc16a1622..7605d0cb2 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -58,6 +58,15 @@ type PageTables struct {
readOnlyShared bool
}
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
// NewWithUpper returns new PageTables.
//
// upperSharedPageTables are used for mapping the upper of addresses,
@@ -73,14 +82,17 @@ type PageTables struct {
func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables {
p := new(PageTables)
p.Init(a)
+
if upperSharedPageTables != nil {
if !upperSharedPageTables.readOnlyShared {
panic("Only read-only shared pagetables can be used as upper")
}
p.upperSharedPageTables = upperSharedPageTables
p.upperStart = upperStart
- p.cloneUpperShared()
}
+
+ p.InitArch(a)
+
return p
}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index a4e416af7..520161755 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -24,6 +24,14 @@ import (
// archPageTables is architecture-specific data.
type archPageTables struct {
+ // root is the pagetable root for kernel space.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
asid uint16
}
@@ -38,7 +46,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
//
//go:nosplit
func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
- return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+ return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}
// Bits in page table entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index e7ab887e5..4bdde8448 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -41,13 +41,13 @@ const (
entriesPerPage = 512
)
-// Init initializes a set of PageTables.
+// InitArch does some additional initialization related to the architecture.
//
//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+func (p *PageTables) InitArch(allocator Allocator) {
+ if p.upperSharedPageTables != nil {
+ p.cloneUpperShared()
+ }
}
func pgdIndex(upperStart uintptr) uintptr {
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 5392bf27a..ad0e30c88 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -42,13 +42,16 @@ const (
entriesPerPage = 512
)
-// Init initializes a set of PageTables.
+// InitArch does some additional initialization related to the architecture.
//
//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+func (p *PageTables) InitArch(allocator Allocator) {
+ if p.upperSharedPageTables != nil {
+ p.cloneUpperShared()
+ } else {
+ p.archPageTables.root = p.Allocator.NewPTEs()
+ p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+ }
}
// cloneUpperShared clone the upper from the upper shared page tables.
@@ -59,7 +62,8 @@ func (p *PageTables) cloneUpperShared() {
panic("upperStart should be the same as upperBottom")
}
- // nothing to do for arm.
+ p.archPageTables.root = p.upperSharedPageTables.archPageTables.root
+ p.archPageTables.rootPhysical = p.upperSharedPageTables.archPageTables.rootPhysical
}
// PTEs is a collection of entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
index 157c9a7cc..c261d393a 100644
--- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr {
func (w *Walker) iterateRangeCanonical(start, end uintptr) {
pgdEntryIndex := w.pageTables.root
if start >= upperBottom {
- pgdEntryIndex = w.pageTables.upperSharedPageTables.root
+ pgdEntryIndex = w.pageTables.archPageTables.root
}
for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index a3f775d15..cc1f6bfcc 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..fb7c5dc61 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -23,7 +23,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..b88cdca48 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
)
}
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
return buf
@@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTimestamp = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp)
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
default:
@@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..1f220c343 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
+ var senderAddrBuf []byte
+ if senderRequested {
+ senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
+ }
+ var controlBuf []byte
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
+
// RecvMsg implements socket.Socket.RecvMsg.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Only allow known and safe flags.
@@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
- var senderAddr linux.SockAddr
var senderAddrBuf []byte
- if senderRequested {
- senderAddrBuf = make([]byte, sizeofSockaddr)
- }
-
var controlBuf []byte
var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
- }
-
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
-
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err := copyToDst()
if flags&syscall.MSG_DONTWAIT == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
@@ -494,48 +505,75 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
+ switch unixCmsg.Header.Type {
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 3baad098b..057f4d294 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -120,9 +120,6 @@ type socketOpsCommon struct {
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
- // passcred indicates if this socket wants SCM credentials.
- passcred bool
-
// filter indicates that this socket has a BPF filter "installed".
//
// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
@@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
// Passcred implements transport.Credentialer.Passcred.
func (s *socketOpsCommon) Passcred() bool {
- s.mu.Lock()
- passcred := s.passcred
- s.mu.Unlock()
- return passcred
+ return s.ep.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
@@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
}
passcred := usermem.ByteOrder.Uint32(opt)
- s.mu.Lock()
- s.passcred = passcred != 0
- s.mu.Unlock()
+ s.ep.SocketOptions().SetPassCred(passcred != 0)
return nil
case linux.SO_ATTACH_FILTER:
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 7d0ae15ca..23d5cab9c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{
MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
ICMP: tcpip.ICMPStats{
- V4PacketsSent: tcpip.ICMPv4SentPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ V4: tcpip.ICMPv4Stats{
+ PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
},
- V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ V6: tcpip.ICMPv6Stats{
+ PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- V6PacketsSent: tcpip.ICMPv6SentPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ IGMP: tcpip.IGMPStats{
+ PacketsSent: tcpip.IGMPSentPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
},
- V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ PacketsReceived: tcpip.IGMPReceivedPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
@@ -209,18 +235,6 @@ const sizeOfInt32 int = 4
var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
-// ntohs converts a 16-bit number from network byte order to host byte order. It
-// assumes that the host is little endian.
-func ntohs(v uint16) uint16 {
- return v<<8 | v>>8
-}
-
-// htons converts a 16-bit number from host byte order to network byte order. It
-// assumes that the host is little endian.
-func htons(v uint16) uint16 {
- return ntohs(v)
-}
-
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
type commonEndpoint interface {
@@ -240,10 +254,6 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
- // transport.Endpoint.SetSockOptBool.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -252,18 +262,20 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
- // transport.Endpoint.GetSockOpt.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
- // LastError implements tcpip.Endpoint.LastError.
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
+ // LastError implements tcpip.Endpoint.LastError and
+ // transport.Endpoint.LastError.
LastError() *tcpip.Error
- // SocketOptions implements tcpip.Endpoint.SocketOptions.
+ // SocketOptions implements tcpip.Endpoint.SocketOptions and
+ // transport.Endpoint.SocketOptions.
SocketOptions() *tcpip.SocketOptions
}
@@ -308,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -332,9 +344,7 @@ type socketOpsCommon struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
dirent := socket.NewDirent(t, netstackDevice)
@@ -363,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
-// AF_INET6, and AF_PACKET addresses.
-//
-// AddressAndFamily returns an address and its family.
-func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
- // Make sure we have at least 2 bytes for the address family.
- if len(addr) < 2 {
- return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
- }
-
- // Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
- case linux.AF_UNIX:
- path := addr[2:]
- if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- // Drop the terminating NUL (if one exists) and everything after
- // it for filesystem (non-abstract) addresses.
- if len(path) > 0 && path[0] != 0 {
- if n := bytes.IndexByte(path[1:], 0); n >= 0 {
- path = path[:n+1]
- }
- }
- return tcpip.FullAddress{
- Addr: tcpip.Address(path),
- }, family, nil
-
- case linux.AF_INET:
- var a linux.SockAddrInet
- if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- return out, family, nil
-
- case linux.AF_INET6:
- var a linux.SockAddrInet6
- if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- if isLinkLocal(out.Addr) {
- out.NIC = tcpip.NICID(a.Scope_id)
- }
- return out, family, nil
-
- case linux.AF_PACKET:
- var a linux.SockAddrLink
- if len(addr) < sockAddrLinkSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
- if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/173): Return protocol too.
- return tcpip.FullAddress{
- NIC: tcpip.NICID(a.InterfaceIndex),
- Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
- }, family, nil
-
- case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, family, nil
-
- default:
- return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
- }
-}
-
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
@@ -480,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -500,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -721,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -749,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -830,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
} else {
var err *syserr.Error
- addr, family, err = AddressAndFamily(sockaddr)
+ addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -921,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1005,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
case linux.SOL_TCP:
- return getSockOptTCP(t, ep, name, outLen)
+ return getSockOptTCP(t, s, ep, name, outLen)
case linux.SOL_IPV6:
return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
@@ -1041,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1068,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.PasscredOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred()))
+ return &v, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1115,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress()))
+ return &v, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort()))
+ return &v, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1174,24 +1080,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive()))
+ return &v, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1222,34 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum()))
+ return &v, nil
case linux.SO_ACCEPTCONN:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
}
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
@@ -1261,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(!v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption()))
+ return &v, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.CorkOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption()))
+ return &v, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck()))
+ return &v, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1474,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1518,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1536,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
@@ -1545,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1565,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1585,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1607,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1649,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
return &a.(*linux.SockAddrInet).Addr, nil
@@ -1658,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1688,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
- case linux.IP_PKTINFO:
+ case linux.IP_HDRINCL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -1719,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
case linux.IPT_SO_GET_INFO:
@@ -1826,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptSocket(t, s, ep, name, optVal)
case linux.SOL_TCP:
- return setSockOptTCP(t, ep, name, optVal)
+ return setSockOptTCP(t, s, ep, name, optVal)
case linux.SOL_IPV6:
return setSockOptIPv6(t, s, ep, name, optVal)
@@ -1876,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+ ep.SocketOptions().SetReuseAddress(v != 0)
+ return nil
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1884,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+ ep.SocketOptions().SetReusePort(v != 0)
+ return nil
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1923,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
+ ep.SocketOptions().SetPassCred(v != 0)
+ return nil
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1931,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
+ ep.SocketOptions().SetKeepAlive(v != 0)
+ return nil
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1970,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1979,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+ ep.SocketOptions().SetNoChecksum(v != 0)
+ return nil
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
@@ -1993,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -2011,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if len(optVal) < sizeOfInt32 {
@@ -2019,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+ ep.SocketOptions().SetDelayOption(v == 0)
+ return nil
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -2027,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+ ep.SocketOptions().SetCorkOption(v != 0)
+ return nil
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -2035,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+ ep.SocketOptions().SetQuickAck(v != 0)
+ return nil
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -2147,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
@@ -2174,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2193,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2201,7 +2136,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2276,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2328,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
- InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]),
}))
case linux.IP_MULTICAST_LOOP:
@@ -2337,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2373,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2383,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2393,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2433,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2511,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2535,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -2582,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
}
}
-// isLinkLocal determines if the given IPv6 address is link-local. This is the
-// case when it has the fe80::/10 prefix. This check is used to determine when
-// the NICID is relevant for a given IPv6 address.
-func isLinkLocal(addr tcpip.Address) bool {
- return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
-}
-
-// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
- switch family {
- case linux.AF_UNIX:
- var out linux.SockAddrUnix
- out.Family = linux.AF_UNIX
- l := len([]byte(addr.Addr))
- for i := 0; i < l; i++ {
- out.Path[i] = int8(addr.Addr[i])
- }
-
- // Linux returns the used length of the address struct (including the
- // null terminator) for filesystem paths. The Family field is 2 bytes.
- // It is sometimes allowed to exclude the null terminator if the
- // address length is the max. Abstract and empty paths always return
- // the full exact length.
- if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return &out, uint32(2 + l)
- }
- return &out, uint32(3 + l)
-
- case linux.AF_INET:
- var out linux.SockAddrInet
- copy(out.Addr[:], addr.Addr)
- out.Family = linux.AF_INET
- out.Port = htons(addr.Port)
- return &out, uint32(sockAddrInetSize)
-
- case linux.AF_INET6:
- var out linux.SockAddrInet6
- if len(addr.Addr) == header.IPv4AddressSize {
- // Copy address in v4-mapped format.
- copy(out.Addr[12:], addr.Addr)
- out.Addr[10] = 0xff
- out.Addr[11] = 0xff
- } else {
- copy(out.Addr[:], addr.Addr)
- }
- out.Family = linux.AF_INET6
- out.Port = htons(addr.Port)
- if isLinkLocal(addr.Addr) {
- out.Scope_id = uint32(addr.NIC)
- }
- return &out, uint32(sockAddrInet6Size)
-
- case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
- var out linux.SockAddrLink
- out.Family = linux.AF_PACKET
- out.InterfaceIndex = int32(addr.NIC)
- out.HardwareAddrLen = header.EthernetAddressSize
- copy(out.HardwareAddr[:], addr.Addr)
- return &out, uint32(sockAddrLinkSize)
-
- default:
- return nil, 0
- }
-}
-
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
@@ -2656,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2668,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2686,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
// Always do at least one fetchReadView, even if the number of bytes to
// read is 0.
err = s.fetchReadView()
- if err != nil {
+ if err != nil || len(s.readView) == 0 {
break
}
if dst.NumBytes() == 0 {
@@ -2709,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
}
copied += n
s.readView.TrimFront(n)
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
- }
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
break
}
+ // If we are done reading requested data then stop.
+ if dst.NumBytes() == 0 {
+ break
+ }
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
}
// If we managed to copy something, we must deliver it.
@@ -2812,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
- addr, addrLen = ConvertAddress(s.family, s.sender)
+ addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
- v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
@@ -2833,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2877,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
@@ -2980,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, family, err := AddressAndFamily(to)
+ addrBuf, family, err := socket.AddressAndFamily(to)
if err != nil {
return 0, err
}
@@ -3399,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3408,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3437,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3447,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index b0d9e4d9e..b756bfca0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// NewVFS2 creates a new endpoint socket.
func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
mnt := t.Kernel().SocketMount()
@@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addrLen uint32
if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index ead3b2b79..c847ff1c7 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 2a01143f6..0af805246 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index fa9ac9059..cc0fadeb5 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
- in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
- out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
- Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors.
0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
@@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
0, // Icmp/OutMsgs.
- Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
- out.DstUnreachable.Value(), // OutDestUnreachs.
- out.TimeExceeded.Value(), // OutTimeExcds.
- out.ParamProblem.Value(), // OutParmProbs.
- out.SrcQuench.Value(), // OutSrcQuenchs.
- out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
- out.EchoReply.Value(), // OutEchoReps.
- out.Timestamp.Value(), // OutTimestamps.
- out.TimestampReply.Value(), // OutTimestampReps.
- out.InfoRequest.Value(), // OutAddrMasks.
- out.InfoReply.Value(), // OutAddrMaskReps.
+ Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fd31479e5..bcc426e33 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -18,6 +18,7 @@
package socket
import (
+ "bytes"
"fmt"
"sync/atomic"
"syscall"
@@ -35,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,7 +44,79 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
}
// Release releases Unix domain socket credentials and rights.
@@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
panic(fmt.Sprintf("Unsupported socket family %v", family))
}
}
+
+var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes()
+var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes()
+var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes()
+
+// Ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func Ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// Htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func Htons(v uint16) uint16 {
+ return Ntohs(v)
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = Htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = Htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation taking any addresses into account.
+func BytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 6d9e502bd..9f7aca305 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -118,28 +118,24 @@ var (
// NewConnectioned creates a new unbound connectionedEndpoint.
func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
- return &connectionedEndpoint{
+ return newConnectioned(ctx, stype, uid)
+}
+
+func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
- a := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
- b := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
+ a := newConnectioned(ctx, stype, uid)
+ b := newConnectioned(ctx, stype, uid)
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
q1.InitRefs()
@@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
- return &connectionedEndpoint{
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// ID implements ConnectingEndpoint.ID.
@@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
+ ne.ops.InitHandler(ne)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
readQueue.InitRefs()
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 1406971bc..0813ad87d 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 18a50e9f8..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,8 +16,6 @@
package transport
import (
- "sync/atomic"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -180,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -191,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -203,10 +193,11 @@ type Endpoint interface {
// procfs.
State() uint32
- // LastError implements tcpip.Endpoint.LastError.
+ // LastError clears and returns the last error reported by the endpoint.
LastError() *tcpip.Error
- // SocketOptions implements tcpip.Endpoint.SocketOptions.
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
SocketOptions() *tcpip.SocketOptions
}
@@ -739,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() {
// +stateify savable
type baseEndpoint struct {
*waiter.Queue
-
- // passcred specifies whether SCM_CREDENTIALS socket control messages are
- // enabled on this endpoint. Must be accessed atomically.
- passcred int32
+ tcpip.DefaultSocketOptionsHandler
// Mutex protects the below fields.
sync.Mutex `state:"nosave"`
@@ -758,9 +746,7 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
+ // ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -786,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
// Passcred implements Credentialer.Passcred.
func (e *baseEndpoint) Passcred() bool {
- return atomic.LoadInt32(&e.passcred) != 0
+ return e.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements Credentialer.ConnectedPasscred.
@@ -796,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool {
return e.connected != nil && e.connected.Passcred()
}
-func (e *baseEndpoint) setPasscred(pc bool) {
- if pc {
- atomic.StoreInt32(&e.passcred, 1)
- } else {
- atomic.StoreInt32(&e.passcred, 0)
- }
-}
-
// Connected implements ConnectingEndpoint.Connected.
func (e *baseEndpoint) Connected() bool {
return e.receiver != nil && e.connected != nil
@@ -859,23 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.PasscredOption:
- e.setPasscred(v)
- case tcpip.ReuseAddressOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
return nil
}
@@ -889,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- case tcpip.PasscredOption:
- return e.Passcred(), nil
-
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -966,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 3e520d2ee..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -115,9 +115,6 @@ type socketOpsCommon struct {
// bound, they cannot be modified.
abstractName string
abstractNamespace *kernel.AbstractSocketNamespace
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
}
func (s *socketOpsCommon) isPacket() bool {
@@ -139,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -172,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -184,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -258,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -650,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -685,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index eaf0b0d26..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index a920180d3..d36a64ffc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -32,8 +32,8 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
- "//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
"//pkg/usermem",
],
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index cc5f70cd4..d943a7cb1 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
- "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(b)
+ fa, _, err := socket.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index bb1f715e2..b815e498f 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 0bf313a13..c2285f796 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := ctx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := ctx.Prepare(); err != nil {
+ return err
}
if eventFile != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 519066a47..8db587401 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
- fSetOwn(t, file, set)
+ fSetOwn(t, int(fd), file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
@@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 {
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
- a := file.Async(fasync.New).(*fasync.FileAsync)
+func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error {
+ a := file.Async(fasync.New(fd)).(*fasync.FileAsync)
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
@@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- return 0, nil, fSetOwn(t, file, args[2].Int())
+ return 0, nil, fSetOwn(t, int(fd), file, args[2].Int())
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
@@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- a := file.Async(fasync.New).(*fasync.FileAsync)
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
switch owner.Type {
case linux.F_OWNER_TID:
task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
@@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
n, err := sz.SetFifoSize(int64(args[2].Int()))
return uintptr(n), nil, err
+ case linux.F_GETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return uintptr(a.Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index e383a0a87..1166cd7bb 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -146,11 +146,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getNCnt(t, id, num)
return uintptr(v), nil, err
- case linux.IPC_INFO,
- linux.SEM_INFO,
- linux.SEM_STAT,
- linux.SEM_STAT_ANY:
+ case linux.IPC_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.IPCInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.SemInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_STAT:
+ arg := args[3].Pointer()
+ // id is an index in SEM_STAT.
+ semid, ds, err := semStat(t, id)
+ if err != nil {
+ return 0, nil, err
+ }
+ if _, err := ds.CopyOut(t, arg); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(semid), nil, err
+ case linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -195,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
return set.GetStat(creds)
}
+func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByIndex(index)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ ds, err := set.GetStat(creds)
+ return set.ID, ds, err
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index e748d33d8..d639c9bf7 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
if err := target.SendGroupSignal(info); err != syserror.ESRCH {
return 0, nil, err
}
@@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
err := tg.SendSignal(info)
if err == syserror.ESRCH {
// ESRCH is ignored because it means the task
@@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
// See note above regarding ESRCH race above.
if err := tg.SendSignal(info); err != syserror.ESRCH {
lastErr = err
@@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI
Signo: int32(sig),
Code: arch.SignalInfoTkill,
}
- info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 983f8d396..8e7ac0ffe 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
si := arch.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
- si.SetPid(int32(wr.TID))
- si.SetUid(int32(wr.UID))
+ si.SetPID(int32(wr.TID))
+ si.SetUID(int32(wr.UID))
// TODO(b/73541790): convert kernel.ExitStatus to functions and make
// WaitResult.Status a linux.WaitStatus.
s := syscall.WaitStatus(wr.Status)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 6d0a38330..1365a5a62 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := aioCtx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := aioCtx.Prepare(); err != nil {
+ return err
}
if eventFD != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 36e89700e..7dd9ef857 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
case linux.F_GETOWN_EX:
owner, hasOwner := getAsyncOwner(t, file)
if !hasOwner {
@@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID)
case linux.F_SETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
@@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
case linux.F_SETLK, linux.F_SETLKW:
return 0, nil, posixLock(t, args, file, cmd)
+ case linux.F_GETSIG:
+ a := file.AsyncHandler()
+ if a == nil {
+ // Default behavior aka SIGIO.
+ return 0, nil, nil
+ }
+ return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
@@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne
}
}
-func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error {
switch ownerType {
case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
// Acceptable type.
@@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32
return syserror.EINVAL
}
- a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync)
if pid == 0 {
a.ClearOwner()
return nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 2806c3f6f..20c264fef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
}
ret, err := file.Ioctl(t, t.MemoryManager(), args)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 440c9307c..a3868bf16 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -105,6 +105,7 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index a98aac52b..072655fe8 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
// Add epi to file.epolls so that it is removed when the last
@@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready with the new mask.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
return nil
@@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error
}
// Callback implements waiter.EntryCallback.Callback.
-func (epi *epollInterest) Callback(*waiter.Entry) {
+func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) {
newReady := false
epi.epoll.mu.Lock()
if !epi.ready {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 936f9fc71..5321ac80a 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -15,12 +15,14 @@
package vfs
import (
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
@@ -42,7 +44,7 @@ import (
type FileDescription struct {
FileDescriptionRefs
- // flagsMu protects statusFlags and asyncHandler below.
+ // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below.
flagsMu sync.Mutex `state:"nosave"`
// statusFlags contains status flags, "initialized by open(2) and possibly
@@ -51,6 +53,11 @@ type FileDescription struct {
// access to asyncHandler.
statusFlags uint32
+ // saved is true after beforeSave is called. This is used to prevent
+ // double-unregistration of asyncHandler. This does not work properly for
+ // save-resume, which is not currently supported in gVisor (see b/26588733).
+ saved bool `state:"nosave"`
+
// asyncHandler handles O_ASYNC signal generation. It is set with the
// F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
// also be set by fcntl(2).
@@ -183,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
fd.asyncHandler = nil
@@ -583,7 +590,11 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
if !fd.readable {
return 0, syserror.EBADF
}
- return fd.impl.PRead(ctx, dst, offset, opts)
+ start := fsmetric.StartReadWait()
+ n, err := fd.impl.PRead(ctx, dst, offset, opts)
+ fsmetric.Reads.Increment()
+ fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+ return n, err
}
// Read is similar to PRead, but does not specify an offset.
@@ -591,7 +602,11 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt
if !fd.readable {
return 0, syserror.EBADF
}
- return fd.impl.Read(ctx, dst, opts)
+ start := fsmetric.StartReadWait()
+ n, err := fd.impl.Read(ctx, dst, opts)
+ fsmetric.Reads.Increment()
+ fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+ return n, err
}
// PWrite writes src to the file represented by fd, starting at the given
@@ -825,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
return fd.asyncHandler
}
-// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
-type FileReadWriteSeeker struct {
- FD *FileDescription
- Ctx context.Context
- ROpts ReadOptions
- WOpts WriteOptions
-}
-
-// ReadAt implements io.ReaderAt.ReadAt.
-func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
- return int(n), err
-}
-
-// Read implements io.ReadWriteSeeker.Read.
-func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(n), err
-}
-
-// Seek implements io.ReadWriteSeeker.Seek.
-func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
- return f.FD.Seek(f.Ctx, offset, int32(whence))
-}
-
-// WriteAt implements io.WriterAt.WriteAt.
-func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
- return int(n), err
-}
-
-// Write implements io.ReadWriteSeeker.Write.
-func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
- buf := usermem.BytesIOSequence(p)
- n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(n), err
+// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD
+// returns EOF or an error. It returns the number of bytes copied.
+func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) {
+ done := int64(0)
+ buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := srcFD.Read(ctx, buf, ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ return done, readErr
+ }
+ src := buf.TakeFirst64(readN)
+ for src.NumBytes() != 0 {
+ writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{})
+ done += writeN
+ src = src.DropFirst64(writeN)
+ if writeErr != nil {
+ return done, writeErr
+ }
+ }
+ if readErr == io.EOF {
+ return done, nil
+ }
+ }
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+var (
+ mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+ mountKeySeed = sync.RandUintptr()
+)
+
+func (k *mountKey) hash() uintptr {
+ return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 7723ed643..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -18,8 +18,10 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// FilesystemImplSaveRestoreExtension is an optional extension to
@@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
@@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() {
func (epi *epollInterest) afterLoad() {
// Mark all epollInterests as ready after restore so that the next call to
// EpollInstance.ReadEvents() rechecks their readiness.
- epi.Callback(nil)
+ epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask))
+}
+
+// beforeSave is called by stateify.
+func (fd *FileDescription) beforeSave() {
+ fd.saved = true
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+}
+
+// afterLoad is called by stateify.
+func (fd *FileDescription) afterLoad() {
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Register(fd)
+ }
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 48d6252f7..6fd1bb0b2 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -41,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
@@ -381,6 +382,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ fsmetric.Opens.Increment()
+
// Remove:
//
// - O_CLOEXEC, which affects file descriptors and therefore must be
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 1d1062aeb..8e3146d8d 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -338,6 +338,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
tid := w.k.TaskSet().Root.IDOfTask(t)
buf.WriteString(fmt.Sprintf("\tTask tid: %v (goroutine %d), entered RunSys state %v ago.\n", tid, t.GoroutineID(), now.Sub(o.lastUpdateTime)))
}
+ buf.WriteString("Search for 'goroutine <id>' in the stack dump to find the offending goroutine(s)")
// Force stack dump only if a new task is detected.
w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
index d462c3eef..e8315326d 100644
--- a/pkg/shim/v1/proc/process.go
+++ b/pkg/shim/v1/proc/process.go
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package proc contains process-related utilities.
package proc
import (
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
index 05c595bc9..e5b6bf186 100644
--- a/pkg/shim/v1/shim/BUILD
+++ b/pkg/shim/v1/shim/BUILD
@@ -8,6 +8,7 @@ go_library(
"api.go",
"platform.go",
"service.go",
+ "shim.go",
],
visibility = [
"//pkg/shim:__subpackages__",
diff --git a/pkg/sleep/commit_asm.go b/pkg/shim/v1/shim/shim.go
index 75728a97d..1855a8769 100644
--- a/pkg/sleep/commit_asm.go
+++ b/pkg/shim/v1/shim/shim.go
@@ -1,10 +1,11 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
-// http://www.apache.org/licenses/LICENSE-2.0
+// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,9 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
-
-package sleep
-
-// See commit_noasm.go for a description of commitSleep.
-func commitSleep(g uintptr, waitingG *uintptr) bool
+// Package shim contains the core containerd shim implementation.
+package shim
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
index 07e346654..21e75d16d 100644
--- a/pkg/shim/v1/utils/utils.go
+++ b/pkg/shim/v1/utils/utils.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package utils contains utility functions.
package utils
import (
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
index f37fefddc..b0e8daa51 100644
--- a/pkg/shim/v2/BUILD
+++ b/pkg/shim/v2/BUILD
@@ -22,6 +22,7 @@ go_library(
"//runsc/specutils",
"@com_github_burntsushi_toml//:go_default_library",
"@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_cgroups//stats/v1:go_default_library",
"@com_github_containerd_console//:go_default_library",
"@com_github_containerd_containerd//api/events:go_default_library",
"@com_github_containerd_containerd//api/types/task:go_default_library",
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
index 2e39d2c4a..6aaf5fab8 100644
--- a/pkg/shim/v2/service.go
+++ b/pkg/shim/v2/service.go
@@ -28,6 +28,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/containerd/cgroups"
+ cgroupsstats "github.com/containerd/cgroups/stats/v1"
"github.com/containerd/console"
"github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/types/task"
@@ -67,9 +68,15 @@ var (
var _ = (taskAPI.TaskService)(&service{})
-// configFile is the default config file name. For containerd 1.2,
-// we assume that a config.toml should exist in the runtime root.
-const configFile = "config.toml"
+const (
+ // configFile is the default config file name. For containerd 1.2,
+ // we assume that a config.toml should exist in the runtime root.
+ configFile = "config.toml"
+
+ // shimAddressPath is the relative path to a file that contains the address
+ // to the shim UDS. See service.shimAddress.
+ shimAddressPath = "address"
+)
// New returns a new shim service that can be used via GRPC.
func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
@@ -101,6 +108,11 @@ func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()
return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
}
go s.forward(ctx, publisher)
+
+ if address, err := shim.ReadAddress(shimAddressPath); err == nil {
+ s.shimAddress = address
+ }
+
return s, nil
}
@@ -152,6 +164,9 @@ type service struct {
// cancel is a function that needs to be called before the shim stops. The
// function is provided by the caller to New().
cancel func()
+
+ // shimAddress is the location of the UDS used to communicate to containerd.
+ shimAddress string
}
func (s *service) newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
@@ -191,38 +206,58 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container
if err != nil {
return "", err
}
- address, err := shim.SocketAddress(ctx, id)
+ address, err := shim.SocketAddress(ctx, containerdAddress, id)
if err != nil {
return "", err
}
socket, err := shim.NewSocket(address)
if err != nil {
- return "", err
+ // The only time where this would happen is if there is a bug and the socket
+ // was not cleaned up in the cleanup method of the shim or we are using the
+ // grouping functionality where the new process should be run with the same
+ // shim as an existing container.
+ if !shim.SocketEaddrinuse(err) {
+ return "", fmt.Errorf("create new shim socket: %w", err)
+ }
+ if shim.CanConnect(address) {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
+ return "", fmt.Errorf("write existing socket for shim: %w", err)
+ }
+ return address, nil
+ }
+ if err := shim.RemoveSocket(address); err != nil {
+ return "", fmt.Errorf("remove pre-existing socket: %w", err)
+ }
+ if socket, err = shim.NewSocket(address); err != nil {
+ return "", fmt.Errorf("try create new shim socket 2x: %w", err)
+ }
}
- defer socket.Close()
+ cu := cleanup.Make(func() {
+ socket.Close()
+ _ = shim.RemoveSocket(address)
+ })
+ defer cu.Clean()
+
f, err := socket.File()
if err != nil {
return "", err
}
- defer f.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, f)
log.L.Debugf("Executing: %q %s", cmd.Path, cmd.Args)
if err := cmd.Start(); err != nil {
+ f.Close()
return "", err
}
- cu := cleanup.Make(func() {
- cmd.Process.Kill()
- })
- defer cu.Clean()
+ cu.Add(func() { cmd.Process.Kill() })
// make sure to wait after start
go cmd.Wait()
if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
return "", err
}
- if err := shim.WriteAddress("address", address); err != nil {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
return "", err
}
if err := shim.SetScore(cmd.Process.Pid); err != nil {
@@ -675,8 +710,11 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task
func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
log.L.Debugf("Shutdown, id: %s", r.ID)
s.cancel()
+ if s.shimAddress != "" {
+ _ = shim.RemoveSocket(s.shimAddress)
+ }
os.Exit(0)
- return empty, nil
+ panic("Should not get here")
}
func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
@@ -698,48 +736,48 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.
// as runc.
//
// [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
- metrics := &cgroups.Metrics{
- CPU: &cgroups.CPUStat{
- Usage: &cgroups.CPUUsage{
+ metrics := &cgroupsstats.Metrics{
+ CPU: &cgroupsstats.CPUStat{
+ Usage: &cgroupsstats.CPUUsage{
Total: stats.Cpu.Usage.Total,
Kernel: stats.Cpu.Usage.Kernel,
User: stats.Cpu.Usage.User,
PerCPU: stats.Cpu.Usage.Percpu,
},
- Throttling: &cgroups.Throttle{
+ Throttling: &cgroupsstats.Throttle{
Periods: stats.Cpu.Throttling.Periods,
ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
},
},
- Memory: &cgroups.MemoryStat{
+ Memory: &cgroupsstats.MemoryStat{
Cache: stats.Memory.Cache,
- Usage: &cgroups.MemoryEntry{
+ Usage: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Usage.Limit,
Usage: stats.Memory.Usage.Usage,
Max: stats.Memory.Usage.Max,
Failcnt: stats.Memory.Usage.Failcnt,
},
- Swap: &cgroups.MemoryEntry{
+ Swap: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Swap.Limit,
Usage: stats.Memory.Swap.Usage,
Max: stats.Memory.Swap.Max,
Failcnt: stats.Memory.Swap.Failcnt,
},
- Kernel: &cgroups.MemoryEntry{
+ Kernel: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Kernel.Limit,
Usage: stats.Memory.Kernel.Usage,
Max: stats.Memory.Kernel.Max,
Failcnt: stats.Memory.Kernel.Failcnt,
},
- KernelTCP: &cgroups.MemoryEntry{
+ KernelTCP: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.KernelTCP.Limit,
Usage: stats.Memory.KernelTCP.Usage,
Max: stats.Memory.KernelTCP.Max,
Failcnt: stats.Memory.KernelTCP.Failcnt,
},
},
- Pids: &cgroups.PidsStat{
+ Pids: &cgroupsstats.PidsStat{
Current: stats.Pids.Current,
Limit: stats.Pids.Limit,
},
@@ -843,9 +881,7 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er
func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
for e := range s.events {
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := publisher.Publish(ctx, getTopic(e), e)
- cancel()
if err != nil {
// Should not happen.
panic(fmt.Errorf("post event: %w", err))
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index ae0fe1522..48bcdd62b 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -5,10 +5,6 @@ package(licenses = ["notice"])
go_library(
name = "sleep",
srcs = [
- "commit_amd64.s",
- "commit_arm64.s",
- "commit_asm.go",
- "commit_noasm.go",
"sleep_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s
deleted file mode 100644
index d0ef15b20..000000000
--- a/pkg/sleep/commit_arm64.s
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "textflag.h"
-
-#define preparingG 1
-
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVD waitingG+8(FP), R0
- MOVD $preparingG, R1
- MOVD G+0(FP), R2
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
-again:
- LDAXR (R0), R3
- CMP R1, R3
- BNE ok
- STLXR R2, (R0), R3
- CBNZ R3, again
-ok:
- CSET EQ, R0
- MOVB R0, ret+16(FP)
- RET
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
deleted file mode 100644
index f59061f37..000000000
--- a/pkg/sleep/commit_noasm.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build !race
-// +build !amd64,!arm64
-
-package sleep
-
-import "sync/atomic"
-
-// commitSleep signals to wakers that the given g is now sleeping. Wakers can
-// then fetch it and wake it.
-//
-// The commit may fail if wakers have been asserted after our last check, in
-// which case they will have set s.waitingG to zero.
-//
-// It is written in assembly because it is called from g0, so it doesn't have
-// a race context.
-func commitSleep(g uintptr, waitingG *uintptr) bool {
- // Try to store the G so that wakers know who to wake.
- return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
-}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 19bce2afb..c44206b1e 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// Package sleep allows goroutines to efficiently sleep on multiple sources of
// notifications (wakers). It offers O(1) complexity, which is different from
// multi-channel selects which have O(n) complexity (where n is the number of
@@ -91,12 +86,6 @@ var (
assertedSleeper Sleeper
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g uintptr, traceskip int)
-
// Sleeper allows a goroutine to sleep and receive wake up notifications from
// Wakers in an efficient way.
//
@@ -189,7 +178,7 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
// See:runtime2.go in the go runtime package for
// the values to pass as the waitReason here.
const waitReasonSelect = 9
- gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(commitSleep, unsafe.Pointer(&s.waitingG), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
// Pull the shared list out and reverse it in the local
@@ -212,6 +201,18 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
return w
}
+// commitSleep signals to wakers that the given g is now sleeping. Wakers can
+// then fetch it and wake it.
+//
+// The commit may fail if wakers have been asserted after our last check, in
+// which case they will have set s.waitingG to zero.
+//
+//go:norace
+//go:nosplit
+func commitSleep(g uintptr, waitingG unsafe.Pointer) bool {
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g)
+}
+
// Fetch fetches the next wake-up notification. If a notification is immediately
// available, it is returned right away. Otherwise, the behavior depends on the
// value of 'block': if true, the current goroutine blocks until a notification
@@ -311,7 +312,7 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
case 0, preparingG:
default:
// We managed to get a G. Wake it up.
- goready(g, 0)
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
index d3931c952..2b1609af0 100644
--- a/pkg/state/tests/integer_test.go
+++ b/pkg/state/tests/integer_test.go
@@ -20,21 +20,21 @@ import (
)
var (
- allIntTs = []int{-1, 0, 1}
- allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
- allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
- allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
- allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
- allUintTs = []uint{0, 1}
- allUintptrs = []uintptr{0, 1, ^uintptr(0)}
- allUint8s = []uint8{0, 1, math.MaxUint8}
- allUint16s = []uint16{0, 1, math.MaxUint16}
- allUint32s = []uint32{0, 1, math.MaxUint32}
- allUint64s = []uint64{0, 1, math.MaxUint64}
+ allBasicInts = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allBasicUints = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
)
var allInts = flatten(
- allIntTs,
+ allBasicInts,
allInt8s,
allInt16s,
allInt32s,
@@ -42,7 +42,7 @@ var allInts = flatten(
)
var allUints = flatten(
- allUintTs,
+ allBasicUints,
allUintptrs,
allUint8s,
allUint16s,
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 68535c3b1..28e62abbb 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -10,15 +10,34 @@ exports_files(["LICENSE"])
go_template(
name = "generic_atomicptr",
- srcs = ["atomicptr_unsafe.go"],
+ srcs = ["generic_atomicptr_unsafe.go"],
types = [
"Value",
],
)
go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ ":sync",
+ "//pkg/gohacks",
+ ],
+)
+
+go_template(
name = "generic_seqatomic",
- srcs = ["seqatomic_unsafe.go"],
+ srcs = ["generic_seqatomic_unsafe.go"],
types = [
"Value",
],
@@ -31,18 +50,26 @@ go_library(
name = "sync",
srcs = [
"aliases.go",
- "memmove_unsafe.go",
+ "checklocks_off_unsafe.go",
+ "checklocks_on_unsafe.go",
+ "goyield_go113_unsafe.go",
+ "goyield_unsafe.go",
"mutex_unsafe.go",
"nocopy.go",
"norace_unsafe.go",
+ "race_amd64.s",
+ "race_arm64.s",
"race_unsafe.go",
+ "runtime_unsafe.go",
"rwmutex_unsafe.go",
"seqcount.go",
- "spin_unsafe.go",
"sync.go",
],
marshal = False,
stateify = False,
+ deps = [
+ "//pkg/goid",
+ ],
)
go_test(
diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD
new file mode 100644
index 000000000..3f71ae97d
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/BUILD
@@ -0,0 +1,57 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "test_atomicptrmap",
+ out = "test_atomicptrmap_unsafe.go",
+ package = "atomicptrmap",
+ prefix = "test",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_template_instance(
+ name = "test_atomicptrmap_sharded",
+ out = "test_atomicptrmap_sharded_unsafe.go",
+ consts = {
+ "ShardOrder": "4",
+ },
+ package = "atomicptrmap",
+ prefix = "test",
+ suffix = "Sharded",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_library(
+ name = "atomicptrmap",
+ testonly = 1,
+ srcs = [
+ "atomicptrmap.go",
+ "test_atomicptrmap_sharded_unsafe.go",
+ "test_atomicptrmap_unsafe.go",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "atomicptrmap_test",
+ size = "small",
+ srcs = ["atomicptrmap_test.go"],
+ library = ":atomicptrmap",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/sync/atomicptrmaptest/atomicptrmap.go
index 19d6b0b15..867821ce9 100644
--- a/pkg/syncevent/waiter_asm_unsafe.go
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go
@@ -12,13 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
+// Package atomicptrmap instantiates generic_atomicptrmap for testing.
+package atomicptrmap
-package syncevent
-
-import (
- "unsafe"
-)
-
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
+type testValue struct {
+ val int
+}
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
new file mode 100644
index 000000000..75a9997ef
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package atomicptrmap
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestConsistencyWithGoMap(t *testing.T) {
+ const maxKey = 16
+ var vals [4]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ vals[i] = new(testValue)
+ }
+ var (
+ m = make(map[int64]*testValue)
+ apm testAtomicPtrMap
+ )
+ for i := 0; i < 100000; i++ {
+ // Apply a random operation to both m and apm and expect them to have
+ // the same result. Bias toward CompareAndSwap, which has the most
+ // cases; bias away from Range and RangeRepeatable, which are
+ // relatively expensive.
+ switch rand.Intn(10) {
+ case 0, 1: // Load
+ key := rand.Int63n(maxKey)
+ want := m[key]
+ got := apm.Load(key)
+ t.Logf("Load(%d) = %p", key, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 2, 3: // Swap
+ key := rand.Int63n(maxKey)
+ val := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if val != nil {
+ m[key] = val
+ } else {
+ delete(m, key)
+ }
+ got := apm.Swap(key, val)
+ t.Logf("Swap(%d, %p) = %p", key, val, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 4, 5, 6, 7: // CompareAndSwap
+ key := rand.Int63n(maxKey)
+ oldVal := vals[rand.Intn(len(vals))]
+ newVal := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if want == oldVal {
+ if newVal != nil {
+ m[key] = newVal
+ } else {
+ delete(m, key)
+ }
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 8: // Range
+ got := make(map[int64]*testValue)
+ var (
+ haveDup = false
+ dup int64
+ )
+ apm.Range(func(key int64, val *testValue) bool {
+ if _, ok := got[key]; ok && !haveDup {
+ haveDup = true
+ dup = key
+ }
+ got[key] = val
+ return true
+ })
+ t.Logf("Range() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ if haveDup {
+ t.Fatalf("got duplicate key %d", dup)
+ }
+ case 9: // RangeRepeatable
+ got := make(map[int64]*testValue)
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ got[key] = val
+ return true
+ })
+ t.Logf("RangeRepeatable() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ }
+ }
+}
+
+func TestConcurrentHeterogeneous(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ var (
+ apm testAtomicPtrMap
+ wg sync.WaitGroup
+ )
+ defer func() {
+ cancel()
+ wg.Wait()
+ }()
+
+ possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
+ addKeyValuePair := func(key int64, val *testValue) {
+ values := possibleKeyValuePairs[key]
+ if values == nil {
+ values = make(map[*testValue]struct{})
+ possibleKeyValuePairs[key] = values
+ }
+ values[val] = struct{}{}
+ }
+
+ const numValuesPerKey = 4
+
+ // These goroutines use keys not used by any other goroutine.
+ const numPrivateKeys = 3
+ for i := 0; i < numPrivateKeys; i++ {
+ key := int64(i)
+ var vals [numValuesPerKey]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ val := new(testValue)
+ vals[i] = val
+ addKeyValuePair(key, val)
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ var stored *testValue
+ for ctx.Err() == nil {
+ switch r.Intn(4) {
+ case 0:
+ got := apm.Load(key)
+ if got != stored {
+ t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
+ return
+ }
+ case 1:
+ val := vals[r.Intn(len(vals))]
+ want := stored
+ stored = val
+ got := apm.Swap(key, val)
+ if got != want {
+ t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
+ return
+ }
+ case 2, 3:
+ oldVal := vals[r.Intn(len(vals))]
+ newVal := vals[r.Intn(len(vals))]
+ want := stored
+ if stored == oldVal {
+ stored = newVal
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ if got != want {
+ t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // These goroutines share a small set of keys.
+ const numSharedKeys = 2
+ var (
+ sharedKeys [numSharedKeys]int64
+ sharedValues = make(map[int64][]*testValue)
+ sharedValuesSet = make(map[int64]map[*testValue]struct{})
+ )
+ for i := range sharedKeys {
+ key := int64(numPrivateKeys + i)
+ sharedKeys[i] = key
+ vals := make([]*testValue, numValuesPerKey)
+ valsSet := make(map[*testValue]struct{})
+ for j := range vals {
+ val := new(testValue)
+ vals[j] = val
+ valsSet[val] = struct{}{}
+ addKeyValuePair(key, val)
+ }
+ sharedValues[key] = vals
+ sharedValuesSet[key] = valsSet
+ }
+ randSharedValue := func(r *rand.Rand, key int64) *testValue {
+ vals := sharedValues[key]
+ return vals[r.Intn(len(vals))]
+ }
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ keyIndex := r.Intn(len(sharedKeys))
+ key := sharedKeys[keyIndex]
+ var (
+ op string
+ got *testValue
+ )
+ switch r.Intn(4) {
+ case 0:
+ op = "Load"
+ got = apm.Load(key)
+ case 1:
+ op = "Swap"
+ got = apm.Swap(key, randSharedValue(r, key))
+ case 2, 3:
+ op = "CompareAndSwap"
+ got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
+ }
+ if got != nil {
+ valsSet := sharedValuesSet[key]
+ if _, ok := valsSet[got]; !ok {
+ t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // This goroutine repeatedly searches for unused keys.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ key := -1 - r.Int63()
+ if got := apm.Load(key); got != nil {
+ t.Errorf("Load(%d): got %p, wanted nil", key, got)
+ }
+ }
+ }()
+
+ // This goroutine repeatedly calls RangeRepeatable() and checks that each
+ // key corresponds to an expected value.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ abort := false
+ for !abort && ctx.Err() == nil {
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("RangeRepeatable: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ return true
+ })
+ }
+ }()
+
+ // Finally, the main goroutine spins for the length of the test calling
+ // Range() and checking that each key that it observes is unique and
+ // corresponds to an expected value.
+ seenKeys := make(map[int64]struct{})
+ const testDuration = 5 * time.Second
+ end := time.Now().Add(testDuration)
+ abort := false
+ for time.Now().Before(end) {
+ apm.Range(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("Range: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ if _, ok := seenKeys[key]; ok {
+ t.Errorf("Range: got duplicate key %d", key)
+ abort = true
+ return false
+ }
+ seenKeys[key] = struct{}{}
+ return true
+ })
+ if abort {
+ break
+ }
+ for k := range seenKeys {
+ delete(seenKeys, k)
+ }
+ }
+}
+
+type benchmarkableMap interface {
+ Load(key int64) *testValue
+ Store(key int64, val *testValue)
+ LoadOrStore(key int64, val *testValue) (*testValue, bool)
+ Delete(key int64)
+}
+
+// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map.
+type rwMutexMap struct {
+ mu sync.RWMutex
+ m map[int64]*testValue
+}
+
+func (m *rwMutexMap) Load(key int64) *testValue {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.m[key]
+}
+
+func (m *rwMutexMap) Store(key int64, val *testValue) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ m.m[key] = val
+}
+
+func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ if oldVal, ok := m.m[key]; ok {
+ return oldVal, true
+ }
+ m.m[key] = val
+ return val, false
+}
+
+func (m *rwMutexMap) Delete(key int64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.m, key)
+}
+
+// syncMap implements benchmarkableMap for a sync.Map.
+type syncMap struct {
+ m sync.Map
+}
+
+func (m *syncMap) Load(key int64) *testValue {
+ val, ok := m.m.Load(key)
+ if !ok {
+ return nil
+ }
+ return val.(*testValue)
+}
+
+func (m *syncMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ actual, loaded := m.m.LoadOrStore(key, val)
+ return actual.(*testValue), loaded
+}
+
+func (m *syncMap) Delete(key int64) {
+ m.m.Delete(key)
+}
+
+// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap.
+type benchmarkableAtomicPtrMap struct {
+ m testAtomicPtrMap
+}
+
+func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMap) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded.
+type benchmarkableAtomicPtrMapSharded struct {
+ m testAtomicPtrMapSharded
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+var mapImpls = [...]struct {
+ name string
+ ctor func() benchmarkableMap
+}{
+ {
+ name: "RWMutexMap",
+ ctor: func() benchmarkableMap {
+ return new(rwMutexMap)
+ },
+ },
+ {
+ name: "SyncMap",
+ ctor: func() benchmarkableMap {
+ return new(syncMap)
+ },
+ },
+ {
+ name: "AtomicPtrMap",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMap)
+ },
+ },
+ {
+ name: "AtomicPtrMapSharded",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMapSharded)
+ },
+ },
+}
+
+func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.LoadOrStore(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkLoadOrStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLoadOrStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(i))
+ }
+}
+
+func BenchmarkLookupPositive(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupPositive(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(-1 - i))
+ }
+}
+
+func BenchmarkLookupNegative(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupNegative(b, mapImpl.ctor)
+ })
+ }
+}
+
+type benchmarkConcurrentOptions struct {
+ // loadsPerMutationPair is the number of map lookups between each
+ // insertion/deletion pair.
+ loadsPerMutationPair int
+
+ // If changeKeys is true, the keys used by each goroutine change between
+ // iterations of the test.
+ changeKeys bool
+}
+
+func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
+ var (
+ started sync.WaitGroup
+ workers sync.WaitGroup
+ )
+ started.Add(1)
+
+ m := mapCtor()
+ val := &testValue{}
+ // Insert a large number of unused elements into the map so that used
+ // elements are distributed throughout memory.
+ for i := 0; i < 10000; i++ {
+ m.Store(int64(-1-i), val)
+ }
+ // n := ceil(b.N / (opts.loadsPerMutationPair + 2))
+ n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2)
+ for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ {
+ workerID := i
+ workers.Add(1)
+ go func() {
+ defer workers.Done()
+ started.Wait()
+ for i := 0; i < n; i++ {
+ var key int64
+ if opts.changeKeys {
+ key = int64(workerID*n + i)
+ } else {
+ key = int64(workerID)
+ }
+ m.LoadOrStore(key, val)
+ for j := 0; j < opts.loadsPerMutationPair; j++ {
+ m.Load(key)
+ }
+ m.Delete(key)
+ }
+ }()
+ }
+
+ b.ResetTimer()
+ started.Done()
+ workers.Wait()
+}
+
+func BenchmarkConcurrent(b *testing.B) {
+ changeKeysChoices := [...]struct {
+ name string
+ val bool
+ }{
+ {"FixedKeys", false},
+ {"ChangingKeys", true},
+ }
+ writePcts := [...]struct {
+ name string
+ loadsPerMutationPair int
+ }{
+ {"1PercentWrites", 198},
+ {"10PercentWrites", 18},
+ {"50PercentWrites", 2},
+ }
+ for _, changeKeys := range changeKeysChoices {
+ for _, writePct := range writePcts {
+ for _, mapImpl := range mapImpls {
+ name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
+ b.Run(name, func(b *testing.B) {
+ benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
+ loadsPerMutationPair: writePct.loadsPerMutationPair,
+ changeKeys: changeKeys.val,
+ })
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/sync/checklocks_off_unsafe.go b/pkg/sync/checklocks_off_unsafe.go
new file mode 100644
index 000000000..62c81b149
--- /dev/null
+++ b/pkg/sync/checklocks_off_unsafe.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !checklocks
+
+package sync
+
+import (
+ "unsafe"
+)
+
+func noteLock(l unsafe.Pointer) {
+}
+
+func noteUnlock(l unsafe.Pointer) {
+}
diff --git a/pkg/sync/checklocks_on_unsafe.go b/pkg/sync/checklocks_on_unsafe.go
new file mode 100644
index 000000000..24f933ed1
--- /dev/null
+++ b/pkg/sync/checklocks_on_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build checklocks
+
+package sync
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/goid"
+)
+
+// gLocks contains metadata about the locks held by a goroutine.
+type gLocks struct {
+ locksHeld []unsafe.Pointer
+}
+
+// map[goid int]*gLocks
+//
+// Each key may only be written by the G with the goid it refers to.
+//
+// Note that entries are not evicted when a G exit, causing unbounded growth
+// with new G creation / destruction. If this proves problematic, entries could
+// be evicted when no locks are held at the expense of more allocations when
+// taking top-level locks.
+var locksHeld sync.Map
+
+func getGLocks() *gLocks {
+ id := goid.Get()
+
+ var locks *gLocks
+ if l, ok := locksHeld.Load(id); ok {
+ locks = l.(*gLocks)
+ } else {
+ locks = &gLocks{
+ // Initialize space for a few locks.
+ locksHeld: make([]unsafe.Pointer, 0, 8),
+ }
+ locksHeld.Store(id, locks)
+ }
+
+ return locks
+}
+
+func noteLock(l unsafe.Pointer) {
+ locks := getGLocks()
+
+ for _, lock := range locks.locksHeld {
+ if lock == l {
+ panic(fmt.Sprintf("Deadlock on goroutine %d! Double lock of %p: %+v", goid.Get(), l, locks))
+ }
+ }
+
+ // Commit only after checking for panic conditions so that this lock
+ // isn't on the list if the above panic is recovered.
+ locks.locksHeld = append(locks.locksHeld, l)
+}
+
+func noteUnlock(l unsafe.Pointer) {
+ locks := getGLocks()
+
+ if len(locks.locksHeld) == 0 {
+ panic(fmt.Sprintf("Unlock of %p on goroutine %d without any locks held! All locks:\n%s", l, goid.Get(), dumpLocks()))
+ }
+
+ // Search backwards since callers are most likely to unlock in LIFO order.
+ length := len(locks.locksHeld)
+ for i := length - 1; i >= 0; i-- {
+ if l == locks.locksHeld[i] {
+ copy(locks.locksHeld[i:length-1], locks.locksHeld[i+1:length])
+ // Clear last entry to ensure addr can be GC'd.
+ locks.locksHeld[length-1] = nil
+ locks.locksHeld = locks.locksHeld[:length-1]
+ return
+ }
+ }
+
+ panic(fmt.Sprintf("Unlock of %p on goroutine %d without matching lock! All locks:\n%s", l, goid.Get(), dumpLocks()))
+}
+
+func dumpLocks() string {
+ var s strings.Builder
+ locksHeld.Range(func(key, value interface{}) bool {
+ goid := key.(int64)
+ locks := value.(*gLocks)
+
+ // N.B. accessing gLocks of another G is fundamentally racy.
+
+ fmt.Fprintf(&s, "goroutine %d:\n", goid)
+ if len(locks.locksHeld) == 0 {
+ fmt.Fprintf(&s, "\t<none>\n")
+ }
+ for _, lock := range locks.locksHeld {
+ fmt.Fprintf(&s, "\t%p\n", lock)
+ }
+ fmt.Fprintf(&s, "\n")
+
+ return true
+ })
+
+ return s.String()
+}
diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go
index 525c4beed..82b6df18c 100644
--- a/pkg/sync/atomicptr_unsafe.go
+++ b/pkg/sync/generic_atomicptr_unsafe.go
@@ -3,9 +3,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
"sync/atomic"
diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go
new file mode 100644
index 000000000..c70dda6dd
--- /dev/null
+++ b/pkg/sync/generic_atomicptrmap_unsafe.go
@@ -0,0 +1,503 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package atomicptrmap doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package atomicptrmap
+
+import (
+ "reflect"
+ "runtime"
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Key is a required type parameter.
+type Key struct{}
+
+// Value is a required type parameter.
+type Value struct{}
+
+const (
+ // ShardOrder is an optional parameter specifying the base-2 log of the
+ // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
+ // unnecessary synchronization between unrelated concurrent operations,
+ // improving performance for write-heavy workloads, but increase memory
+ // usage for small maps.
+ ShardOrder = 0
+)
+
+// Hasher is an optional type parameter. If Hasher is provided, it must define
+// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
+type Hasher struct {
+ defaultHasher
+}
+
+// defaultHasher is the default Hasher. This indirection exists because
+// defaultHasher must exist even if a custom Hasher is provided, to prevent the
+// Go compiler from complaining about defaultHasher's unused imports.
+type defaultHasher struct {
+ fn func(unsafe.Pointer, uintptr) uintptr
+ seed uintptr
+}
+
+// Init initializes the Hasher.
+func (h *defaultHasher) Init() {
+ h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
+ h.seed = sync.RandUintptr()
+}
+
+// Hash returns the hash value for the given Key.
+func (h *defaultHasher) Hash(key Key) uintptr {
+ return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
+}
+
+var hasher Hasher
+
+func init() {
+ hasher.Init()
+}
+
+// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are
+// safe for concurrent use from multiple goroutines without additional
+// synchronization.
+//
+// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
+// use. AtomicPtrMaps must not be copied after first use.
+//
+// sync.Map may be faster than AtomicPtrMap if most operations on the map are
+// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
+// other circumstances.
+type AtomicPtrMap struct {
+ // AtomicPtrMap is implemented as a hash table with the following
+ // properties:
+ //
+ // * Collisions are resolved with quadratic probing. Of the two major
+ // alternatives, Robin Hood linear probing makes it difficult for writers
+ // to execute in parallel, and bucketing is less effective in Go due to
+ // lack of SIMD.
+ //
+ // * The table is optionally divided into shards indexed by hash to further
+ // reduce unnecessary synchronization.
+
+ shards [1 << ShardOrder]apmShard
+}
+
+func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
+ // Go defines right shifts >= width of shifted unsigned operand as 0, so
+ // this is correct even if ShardOrder is 0 (although nogo complains because
+ // nogo is dumb).
+ const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
+ index := hash >> indexLSB
+ return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
+}
+
+type apmShard struct {
+ apmShardMutationData
+ _ [apmShardMutationDataPadding]byte
+ apmShardLookupData
+ _ [apmShardLookupDataPadding]byte
+}
+
+type apmShardMutationData struct {
+ dirtyMu sync.Mutex // serializes slot transitions out of empty
+ dirty uintptr // # slots with val != nil
+ count uintptr // # slots with val != nil and val != tombstone()
+ rehashMu sync.Mutex // serializes rehashing
+}
+
+type apmShardLookupData struct {
+ seq sync.SeqCount // allows atomic reads of slots+mask
+ slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
+ mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq
+}
+
+const (
+ cacheLineBytes = 64
+ // Cache line padding is enabled if sharding is.
+ apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
+ // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
+ // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
+ apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
+ apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding
+ apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
+ apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding
+
+ // These define fractional thresholds for when apmShard.rehash() is called
+ // (i.e. the load factor) and when it rehases to a larger table
+ // respectively. They are chosen such that the rehash threshold = the
+ // expansion threshold + 1/2, so that when reuse of deleted slots is rare
+ // or non-existent, rehashing occurs after the insertion of at least 1/2
+ // the table's size in new entries, which is acceptably infrequent.
+ apmRehashThresholdNum = 2
+ apmRehashThresholdDen = 3
+ apmExpansionThresholdNum = 1
+ apmExpansionThresholdDen = 6
+)
+
+type apmSlot struct {
+ // slot states are indicated by val:
+ //
+ // * Empty: val == nil; key is meaningless. May transition to full or
+ // evacuated with dirtyMu locked.
+ //
+ // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val
+ // is the Value mapped to key. May transition to deleted or evacuated.
+ //
+ // * Deleted: val == tombstone(); key is still immutable. key is mapped to
+ // no Value. May transition to full or evacuated.
+ //
+ // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
+ // slots that have already been moved, requiring readers to wait for
+ // rehashing to complete and use the new table. Terminal state.
+ //
+ // Note that once val is non-nil, it cannot become nil again. That is, the
+ // transition from empty to non-empty is irreversible for a given slot;
+ // the only way to create more empty slots is by rehashing.
+ val unsafe.Pointer
+ key Key
+}
+
+func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
+ return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
+}
+
+var tombstoneObj byte
+
+func tombstone() unsafe.Pointer {
+ return unsafe.Pointer(&tombstoneObj)
+}
+
+var evacuatedObj byte
+
+func evacuated() unsafe.Pointer {
+ return unsafe.Pointer(&evacuatedObj)
+}
+
+// Load returns the Value stored in m for key.
+func (m *AtomicPtrMap) Load(key Key) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ return nil
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Empty slot; end of probe sequence.
+ return nil
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// Store stores the Value val for key.
+func (m *AtomicPtrMap) Store(key Key, val *Value) {
+ m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// Swap stores the Value val for key and returns the previously-mapped Value.
+func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
+ return m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
+// stores the Value newVal for key. CompareAndSwap returns the previous Value
+// stored for key, whether or not it stores newVal.
+func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
+ return m.maybeCompareAndSwap(key, true, oldVal, newVal)
+}
+
+func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+ oldVal := tombstone()
+ if typedOldVal != nil {
+ oldVal = unsafe.Pointer(typedOldVal)
+ }
+ newVal := tombstone()
+ if typedNewVal != nil {
+ newVal = unsafe.Pointer(typedNewVal)
+ }
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Need to allocate a table before insertion.
+ shard.rehash(nil)
+ goto retry
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Try to grab this slot for ourselves.
+ shard.dirtyMu.Lock()
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Check if we need to rehash before dirtying a slot.
+ if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
+ shard.dirtyMu.Unlock()
+ shard.rehash(slots)
+ goto retry
+ }
+ slot.key = key
+ atomic.StorePointer(&slot.val, newVal) // transitions slot to full
+ shard.dirty++
+ atomic.AddUintptr(&shard.count, 1)
+ shard.dirtyMu.Unlock()
+ return nil
+ }
+ // Raced with another store; the slot is no longer empty. Continue
+ // with the new value of slotVal since we may have raced with
+ // another store of key.
+ shard.dirtyMu.Unlock()
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ // We're reusing an existing slot, so rehashing isn't necessary.
+ for {
+ if (compare && oldVal != slotVal) || newVal == slotVal {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
+ if slotVal == tombstone() {
+ atomic.AddUintptr(&shard.count, 1)
+ return nil
+ }
+ if newVal == tombstone() {
+ atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
+ }
+ return (*Value)(slotVal)
+ }
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ }
+ }
+ // This produces a triangular number sequence of offsets from the
+ // initially-probed position.
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// rehash is marked nosplit to avoid preemption during table copying.
+//go:nosplit
+func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+
+ if shard.slots != oldSlots {
+ // Raced with another call to rehash().
+ return
+ }
+
+ // Determine the size of the new table. Constraints:
+ //
+ // * The size of the table must be a power of two to ensure that every slot
+ // is visitable by every probe sequence under quadratic probing with
+ // triangular numbers.
+ //
+ // * The size of the table cannot decrease because even if shard.count is
+ // currently smaller than shard.dirty, concurrent stores that reuse
+ // existing slots can drive shard.count back up to a maximum of
+ // shard.dirty.
+ newSize := uintptr(8) // arbitrary initial size
+ if oldSlots != nil {
+ oldSize := shard.mask + 1
+ newSize = oldSize
+ if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
+ newSize *= 2
+ }
+ }
+
+ // Allocate the new table.
+ newSlotsSlice := make([]apmSlot, newSize)
+ newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
+ newSlots := unsafe.Pointer(newSlotsReflect.Data)
+ runtime.KeepAlive(newSlotsSlice)
+ newMask := newSize - 1
+
+ // Start a writer critical section now so that racing users of the old
+ // table that observe evacuated() wait for the new table. (But lock dirtyMu
+ // first since doing so may block, which we don't want to do during the
+ // writer critical section.)
+ shard.dirtyMu.Lock()
+ shard.seq.BeginWrite()
+
+ if oldSlots != nil {
+ realCount := uintptr(0)
+ // Copy old entries to the new table.
+ oldMask := shard.mask
+ for i := uintptr(0); i <= oldMask; i++ {
+ oldSlot := apmSlotAt(oldSlots, i)
+ val := atomic.SwapPointer(&oldSlot.val, evacuated())
+ if val == nil || val == tombstone() {
+ continue
+ }
+ hash := hasher.Hash(oldSlot.key)
+ j := hash & newMask
+ inc := uintptr(1)
+ for {
+ newSlot := apmSlotAt(newSlots, j)
+ if newSlot.val == nil {
+ newSlot.val = val
+ newSlot.key = oldSlot.key
+ break
+ }
+ j = (j + inc) & newMask
+ inc++
+ }
+ realCount++
+ }
+ // Update dirty to reflect that tombstones were not copied to the new
+ // table. Use realCount since a concurrent mutator may not have updated
+ // shard.count yet.
+ shard.dirty = realCount
+ }
+
+ // Switch to the new table.
+ atomic.StorePointer(&shard.slots, newSlots)
+ atomic.StoreUintptr(&shard.mask, newMask)
+
+ shard.seq.EndWrite()
+ shard.dirtyMu.Unlock()
+}
+
+// Range invokes f on each Key-Value pair stored in m. If any call to f returns
+// false, Range stops iteration and returns.
+//
+// Range does not necessarily correspond to any consistent snapshot of the
+// Map's contents: no Key will be visited more than once, but if the Value for
+// any Key is stored or deleted concurrently, Range may reflect any mapping for
+// that Key from any point during the Range call.
+//
+// f must not call other methods on m.
+func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+ if !shard.doRange(f) {
+ return
+ }
+ }
+}
+
+func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
+ // We have to lock rehashMu because if we handled races with rehashing by
+ // retrying, f could see the same key twice.
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+ slots := shard.slots
+ if slots == nil {
+ return true
+ }
+ mask := shard.mask
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return false
+ }
+ }
+ return true
+}
+
+// RangeRepeatable is like Range, but:
+//
+// * RangeRepeatable may visit the same Key multiple times in the presence of
+// concurrent mutators, possibly passing different Values to f in different
+// calls.
+//
+// * It is safe for f to call other methods on m.
+func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+
+ retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ continue
+ }
+
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return
+ }
+ }
+ }
+}
diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go
index 2184cb5ab..82b676abf 100644
--- a/pkg/sync/seqatomic_unsafe.go
+++ b/pkg/sync/generic_seqatomic_unsafe.go
@@ -3,25 +3,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
- "fmt"
- "reflect"
- "strings"
"unsafe"
"gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
-//
-// Value must not contain any pointers, including interface objects, function
-// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
-// containing any of the above. An init() function will panic if this property
-// does not hold.
type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
@@ -55,12 +47,3 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
ok = seq.ReadOk(epoch)
return
}
-
-func init() {
- var val Value
- typ := reflect.TypeOf(val)
- name := typ.Name()
- if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
- panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
- }
-}
diff --git a/pkg/sync/goyield_go113_unsafe.go b/pkg/sync/goyield_go113_unsafe.go
new file mode 100644
index 000000000..8aee0d455
--- /dev/null
+++ b/pkg/sync/goyield_go113_unsafe.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.14
+
+package sync
+
+import (
+ "runtime"
+)
+
+func goyield() {
+ // goyield is not available until Go 1.14.
+ runtime.Gosched()
+}
diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/goyield_unsafe.go
index cafb2d065..672ee274d 100644
--- a/pkg/sync/spin_unsafe.go
+++ b/pkg/sync/goyield_unsafe.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.13
+// +build go1.14
// +build !go1.17
// Check go:linkname function signatures when updating Go version.
@@ -14,11 +14,5 @@ import (
_ "unsafe" // for go:linkname
)
-//go:linkname canSpin sync.runtime_canSpin
-func canSpin(i int) bool
-
-//go:linkname doSpin sync.runtime_doSpin
-func doSpin()
-
//go:linkname goyield runtime.goyield
func goyield()
diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
deleted file mode 100644
index f5e630009..000000000
--- a/pkg/sync/memmove_unsafe.go
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
-package sync
-
-import (
- "unsafe"
-)
-
-//go:linkname memmove runtime.memmove
-//go:noescape
-func memmove(to, from unsafe.Pointer, n uintptr)
-
-// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't
-// define it because go_generics can't update the go:linkname annotation.
-// Furthermore, go:linkname silently doesn't work if the local name is exported
-// (this is of course undocumented), which is why this indirection is
-// necessary.
-func Memmove(to, from unsafe.Pointer, n uintptr) {
- memmove(to, from, n)
-}
diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go
index 0838248b4..4fb51a8ab 100644
--- a/pkg/sync/mutex_test.go
+++ b/pkg/sync/mutex_test.go
@@ -32,11 +32,11 @@ func TestStructSize(t *testing.T) {
func TestFieldValues(t *testing.T) {
var m Mutex
m.Lock()
- if got := *m.state(); got != mutexLocked {
+ if got := *m.m.state(); got != mutexLocked {
t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked)
}
m.Unlock()
- if got := *m.state(); got != mutexUnlocked {
+ if got := *m.m.state(); got != mutexUnlocked {
t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked)
}
}
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index f4c2e9642..21084b857 100644
--- a/pkg/sync/mutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
@@ -17,8 +17,9 @@ import (
"unsafe"
)
-// Mutex is a try lock.
-type Mutex struct {
+// CrossGoroutineMutex is equivalent to Mutex, but it need not be unlocked by a
+// the same goroutine that locked the mutex.
+type CrossGoroutineMutex struct {
sync.Mutex
}
@@ -27,7 +28,7 @@ type syncMutex struct {
sema uint32
}
-func (m *Mutex) state() *int32 {
+func (m *CrossGoroutineMutex) state() *int32 {
return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state
}
@@ -36,9 +37,9 @@ const (
mutexLocked = 1
)
-// TryLock tries to aquire the mutex. It returns true if it succeeds and false
+// TryLock tries to acquire the mutex. It returns true if it succeeds and false
// otherwise. TryLock does not block.
-func (m *Mutex) TryLock() bool {
+func (m *CrossGoroutineMutex) TryLock() bool {
if atomic.CompareAndSwapInt32(m.state(), mutexUnlocked, mutexLocked) {
if RaceEnabled {
RaceAcquire(unsafe.Pointer(&m.Mutex))
@@ -47,3 +48,43 @@ func (m *Mutex) TryLock() bool {
}
return false
}
+
+// Mutex is a mutual exclusion lock. The zero value for a Mutex is an unlocked
+// mutex.
+//
+// A Mutex must not be copied after first use.
+//
+// A Mutex must be unlocked by the same goroutine that locked it. This
+// invariant is enforced with the 'checklocks' build tag.
+type Mutex struct {
+ m CrossGoroutineMutex
+}
+
+// Lock locks m. If the lock is already in use, the calling goroutine blocks
+// until the mutex is available.
+func (m *Mutex) Lock() {
+ noteLock(unsafe.Pointer(m))
+ m.m.Lock()
+}
+
+// Unlock unlocks m.
+//
+// Preconditions:
+// * m is locked.
+// * m was locked by this goroutine.
+func (m *Mutex) Unlock() {
+ noteUnlock(unsafe.Pointer(m))
+ m.m.Unlock()
+}
+
+// TryLock tries to acquire the mutex. It returns true if it succeeds and false
+// otherwise. TryLock does not block.
+func (m *Mutex) TryLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(m))
+ locked := m.m.TryLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(m))
+ }
+ return locked
+}
diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index 006055dd6..70b5f3a5e 100644
--- a/pkg/sync/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -8,6 +8,7 @@
package sync
import (
+ "sync/atomic"
"unsafe"
)
@@ -33,3 +34,13 @@ func RaceRelease(addr unsafe.Pointer) {
// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
func RaceReleaseMerge(addr unsafe.Pointer) {
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool {
+ // Use atomic.CompareAndSwapUintptr outside of race builds for
+ // inlinability.
+ return atomic.CompareAndSwapUintptr(ptr, old, new)
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/sync/race_amd64.s
index 5e216b045..57bc0ec79 100644
--- a/pkg/syncevent/waiter_amd64.s
+++ b/pkg/sync/race_amd64.s
@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build amd64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
MOVQ ptr+0(FP), DI
- MOVQ wg+8(FP), SI
+ MOVQ old+8(FP), AX
+ MOVQ new+16(FP), SI
- MOVQ $·preparingG(SB), AX
LOCK
- CMPXCHGQ DI, 0(SI)
+ CMPXCHGQ SI, 0(DI)
SETEQ AX
- MOVB AX, ret+16(FP)
+ MOVB AX, ret+24(FP)
RET
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/sync/race_arm64.s
index f4c06f194..88f091fda 100644
--- a/pkg/syncevent/waiter_arm64.s
+++ b/pkg/sync/race_arm64.s
@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build arm64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
- MOVD wg+8(FP), R0
- MOVD $·preparingG(SB), R1
- MOVD ptr+0(FP), R2
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
+ MOVD ptr+0(FP), R0
+ MOVD old+8(FP), R1
+ MOVD new+16(FP), R1
again:
LDAXR (R0), R3
CMP R1, R3
@@ -29,6 +30,6 @@ again:
CBNZ R3, again
ok:
CSET EQ, R0
- MOVB R0, ret+16(FP)
+ MOVB R0, ret+24(FP)
RET
diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go
index 31d8fa9a6..59985c270 100644
--- a/pkg/sync/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -39,3 +39,9 @@ func RaceRelease(addr unsafe.Pointer) {
func RaceReleaseMerge(addr unsafe.Pointer) {
runtime.RaceReleaseMerge(addr)
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
new file mode 100644
index 000000000..e925e2e5b
--- /dev/null
+++ b/pkg/sync/runtime_unsafe.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.17
+
+// Check function signatures and constants when updating Go version.
+
+package sync
+
+import (
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// Note that go:linkname silently doesn't work if the local name is exported,
+// necessitating an indirection for exported functions.
+
+// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
+//
+//go:nosplit
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
+
+// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock);
+// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g)
+// is called. unlockf and its callees must be nosplit and norace, since stack
+// splitting and race context are not available where it is called.
+//
+//go:nosplit
+func Gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int) {
+ gopark(unlockf, lock, reason, traceEv, traceskip)
+}
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+// Goready is runtime.goready.
+//
+//go:nosplit
+func Goready(gp uintptr, traceskip int) {
+ goready(gp, traceskip)
+}
+
+//go:linkname goready runtime.goready
+func goready(gp uintptr, traceskip int)
+
+// Values for the reason argument to gopark, from Go's src/runtime/runtime2.go.
+const (
+ WaitReasonSelect uint8 = 9
+)
+
+// Values for the traceEv argument to gopark, from Go's src/runtime/trace.go.
+const (
+ TraceEvGoBlockSelect byte = 24
+)
+
+// Rand32 returns a non-cryptographically-secure random uint32.
+func Rand32() uint32 {
+ return fastrand()
+}
+
+// Rand64 returns a non-cryptographically-secure random uint64.
+func Rand64() uint64 {
+ return uint64(fastrand())<<32 | uint64(fastrand())
+}
+
+//go:linkname fastrand runtime.fastrand
+func fastrand() uint32
+
+// RandUintptr returns a non-cryptographically-secure random uintptr.
+func RandUintptr() uintptr {
+ if unsafe.Sizeof(uintptr(0)) == 4 {
+ return uintptr(Rand32())
+ }
+ return uintptr(Rand64())
+}
+
+// MapKeyHasher returns a hash function for pointers of m's key type.
+//
+// Preconditions: m must be a map.
+func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr {
+ if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map {
+ panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp))
+ }
+ mtyp := *(**maptype)(unsafe.Pointer(&m))
+ return mtyp.hasher
+}
+
+type maptype struct {
+ size uintptr
+ ptrdata uintptr
+ hash uint32
+ tflag uint8
+ align uint8
+ fieldAlign uint8
+ kind uint8
+ equal func(unsafe.Pointer, unsafe.Pointer) bool
+ gcdata *byte
+ str int32
+ ptrToThis int32
+ key unsafe.Pointer
+ elem unsafe.Pointer
+ bucket unsafe.Pointer
+ hasher func(unsafe.Pointer, uintptr) uintptr
+ // more fields
+}
+
+// These functions are only used within the sync package.
+
+//go:linkname semacquire sync.runtime_Semacquire
+func semacquire(s *uint32)
+
+//go:linkname semrelease sync.runtime_Semrelease
+func semrelease(s *uint32, handoff bool, skipframes int)
+
+//go:linkname canSpin sync.runtime_canSpin
+func canSpin(i int) bool
+
+//go:linkname doSpin sync.runtime_doSpin
+func doSpin()
diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go
index ce667e825..5ca96d12b 100644
--- a/pkg/sync/rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
@@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c
}
for i := 0; i < 100; i++ {
}
- n = atomic.AddInt32(activity, -1)
+ atomic.AddInt32(activity, -1)
rwm.RUnlock()
}
cdone <- true
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index b3b4dee78..4cf3fcd6e 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -3,11 +3,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.13
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// This is mostly copied from the standard library's sync/rwmutex.go.
//
// Happens-before relationships indicated to the race detector:
@@ -23,16 +18,15 @@ import (
"unsafe"
)
-//go:linkname runtimeSemacquire sync.runtime_Semacquire
-func runtimeSemacquire(s *uint32)
-
-//go:linkname runtimeSemrelease sync.runtime_Semrelease
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
-
-// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock,
-// TryLock and TryRLock methods.
-type RWMutex struct {
- w Mutex // held if there are pending writers
+// CrossGoroutineRWMutex is equivalent to RWMutex, but it need not be unlocked
+// by a the same goroutine that locked the mutex.
+type CrossGoroutineRWMutex struct {
+ // w is held if there are pending writers
+ //
+ // We use CrossGoroutineMutex rather than Mutex because the lock
+ // annotation instrumentation in Mutex will trigger false positives in
+ // the race detector when called inside of RaceDisable.
+ w CrossGoroutineMutex
writerSem uint32 // semaphore for writers to wait for completing readers
readerSem uint32 // semaphore for readers to wait for completing writers
readerCount int32 // number of pending readers
@@ -43,7 +37,7 @@ const rwmutexMaxReaders = 1 << 30
// TryRLock locks rw for reading. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryRLock() bool {
+func (rw *CrossGoroutineRWMutex) TryRLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -67,13 +61,17 @@ func (rw *RWMutex) TryRLock() bool {
}
// RLock locks rw for reading.
-func (rw *RWMutex) RLock() {
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *CrossGoroutineRWMutex) RLock() {
if RaceEnabled {
RaceDisable()
}
if atomic.AddInt32(&rw.readerCount, 1) < 0 {
// A writer is pending, wait for it.
- runtimeSemacquire(&rw.readerSem)
+ semacquire(&rw.readerSem)
}
if RaceEnabled {
RaceEnable()
@@ -82,7 +80,10 @@ func (rw *RWMutex) RLock() {
}
// RUnlock undoes a single RLock call.
-func (rw *RWMutex) RUnlock() {
+//
+// Preconditions:
+// * rw is locked for reading.
+func (rw *CrossGoroutineRWMutex) RUnlock() {
if RaceEnabled {
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
@@ -94,7 +95,7 @@ func (rw *RWMutex) RUnlock() {
// A writer is pending.
if atomic.AddInt32(&rw.readerWait, -1) == 0 {
// The last reader unblocks the writer.
- runtimeSemrelease(&rw.writerSem, false, 0)
+ semrelease(&rw.writerSem, false, 0)
}
}
if RaceEnabled {
@@ -104,7 +105,7 @@ func (rw *RWMutex) RUnlock() {
// TryLock locks rw for writing. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryLock() bool {
+func (rw *CrossGoroutineRWMutex) TryLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -130,8 +131,9 @@ func (rw *RWMutex) TryLock() bool {
return true
}
-// Lock locks rw for writing.
-func (rw *RWMutex) Lock() {
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *CrossGoroutineRWMutex) Lock() {
if RaceEnabled {
RaceDisable()
}
@@ -141,7 +143,7 @@ func (rw *RWMutex) Lock() {
r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders
// Wait for active readers.
if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 {
- runtimeSemacquire(&rw.writerSem)
+ semacquire(&rw.writerSem)
}
if RaceEnabled {
RaceEnable()
@@ -150,7 +152,10 @@ func (rw *RWMutex) Lock() {
}
// Unlock unlocks rw for writing.
-func (rw *RWMutex) Unlock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) Unlock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.writerSem))
RaceRelease(unsafe.Pointer(&rw.readerSem))
@@ -163,7 +168,7 @@ func (rw *RWMutex) Unlock() {
}
// Unblock blocked readers, if any.
for i := 0; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed.
rw.w.Unlock()
@@ -173,7 +178,10 @@ func (rw *RWMutex) Unlock() {
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
-func (rw *RWMutex) DowngradeLock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) DowngradeLock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.readerSem))
RaceDisable()
@@ -186,7 +194,7 @@ func (rw *RWMutex) DowngradeLock() {
// Unblock blocked readers, if any. Note that this loop starts as 1 since r
// includes this goroutine.
for i := 1; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed to rw.w.Lock(). Note that they will still
// block on rw.writerSem since at least this reader exists, such that
@@ -196,3 +204,91 @@ func (rw *RWMutex) DowngradeLock() {
RaceEnable()
}
}
+
+// A RWMutex is a reader/writer mutual exclusion lock. The lock can be held by
+// an arbitrary number of readers or a single writer. The zero value for a
+// RWMutex is an unlocked mutex.
+//
+// A RWMutex must not be copied after first use.
+//
+// If a goroutine holds a RWMutex for reading and another goroutine might call
+// Lock, no goroutine should expect to be able to acquire a read lock until the
+// initial read lock is released. In particular, this prohibits recursive read
+// locking. This is to ensure that the lock eventually becomes available; a
+// blocked Lock call excludes new readers from acquiring the lock.
+//
+// A Mutex must be unlocked by the same goroutine that locked it. This
+// invariant is enforced with the 'checklocks' build tag.
+type RWMutex struct {
+ m CrossGoroutineRWMutex
+}
+
+// TryRLock locks rw for reading. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryRLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryRLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// RLock locks rw for reading.
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *RWMutex) RLock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.RLock()
+}
+
+// RUnlock undoes a single RLock call.
+//
+// Preconditions:
+// * rw is locked for reading.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) RUnlock() {
+ rw.m.RUnlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// TryLock locks rw for writing. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *RWMutex) Lock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.Lock()
+}
+
+// Unlock unlocks rw for writing.
+//
+// Preconditions:
+// * rw is locked for writing.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) Unlock() {
+ rw.m.Unlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *RWMutex) DowngradeLock() {
+ // No note change for DowngradeLock.
+ rw.m.DowngradeLock()
+}
diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go
index 2c5d3df99..1f025f33c 100644
--- a/pkg/sync/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -6,8 +6,6 @@
package sync
import (
- "fmt"
- "reflect"
"sync/atomic"
)
@@ -27,9 +25,6 @@ import (
// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
// operations to be made atomic with reads of SeqCount-protected data.
//
-// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
-// cannot include pointers.
-//
// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
// data require instantiating function templates using go_generics (see
// seqatomic.go).
@@ -128,32 +123,3 @@ func (s *SeqCount) EndWrite() {
panic("SeqCount.EndWrite outside writer critical section")
}
}
-
-// PointersInType returns a list of pointers reachable from values named
-// valName of the given type.
-//
-// PointersInType is not exhaustive, but it is guaranteed that if typ contains
-// at least one pointer, then PointersInTypeOf returns a non-empty list.
-func PointersInType(typ reflect.Type, valName string) []string {
- switch kind := typ.Kind(); kind {
- case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
- return nil
-
- case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
- return []string{valName}
-
- case reflect.Array:
- return PointersInType(typ.Elem(), valName+"[]")
-
- case reflect.Struct:
- var ptrs []string
- for i, n := 0, typ.NumField(); i < n; i++ {
- field := typ.Field(i)
- ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
- }
- return ptrs
-
- default:
- return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
- }
-}
diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go
index 6eb7b4b59..3f5592e3e 100644
--- a/pkg/sync/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -6,7 +6,6 @@
package sync
import (
- "reflect"
"testing"
"time"
)
@@ -99,55 +98,3 @@ func BenchmarkSeqCountReadUncontended(b *testing.B) {
}
})
}
-
-func TestPointersInType(t *testing.T) {
- for _, test := range []struct {
- name string // used for both test and value name
- val interface{}
- ptrs []string
- }{
- {
- name: "EmptyStruct",
- val: struct{}{},
- },
- {
- name: "Int",
- val: int(0),
- },
- {
- name: "MixedStruct",
- val: struct {
- b bool
- I int
- ExportedPtr *struct{}
- unexportedPtr *struct{}
- arr [2]int
- ptrArr [2]*int
- nestedStruct struct {
- nestedNonptr int
- nestedPtr *int
- }
- structArr [1]struct {
- nonptr int
- ptr *int
- }
- }{},
- ptrs: []string{
- "MixedStruct.ExportedPtr",
- "MixedStruct.unexportedPtr",
- "MixedStruct.ptrArr[]",
- "MixedStruct.nestedStruct.nestedPtr",
- "MixedStruct.structArr[].ptr",
- },
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- typ := reflect.TypeOf(test.val)
- ptrs := PointersInType(typ, test.name)
- t.Logf("Found pointers: %v", ptrs)
- if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
- t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
- }
- })
- }
-}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
index 0500a22cf..42c553308 100644
--- a/pkg/syncevent/BUILD
+++ b/pkg/syncevent/BUILD
@@ -9,10 +9,6 @@ go_library(
"receiver.go",
"source.go",
"syncevent.go",
- "waiter_amd64.s",
- "waiter_arm64.s",
- "waiter_asm_unsafe.go",
- "waiter_noasm_unsafe.go",
"waiter_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
deleted file mode 100644
index 0f74a689c..000000000
--- a/pkg/syncevent/waiter_noasm_unsafe.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// waiterUnlock is called from g0, so when the race detector is enabled,
-// waiterUnlock must be implemented in assembly since no race context is
-// available.
-//
-// +build !race
-// +build !amd64,!arm64
-
-package syncevent
-
-import (
- "sync/atomic"
- "unsafe"
-)
-
-// waiterUnlock is the "unlock function" passed to runtime.gopark by
-// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
-// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
-// should be aborted.
-//
-//go:nosplit
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool {
- // The only way this CAS can fail is if a call to Waiter.NotifyPending()
- // has replaced *wg with nil, in which case we should not sleep.
- return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr)
-}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
index 518f18479..b6ed2852d 100644
--- a/pkg/syncevent/waiter_unsafe.go
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package syncevent
import (
@@ -26,17 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g unsafe.Pointer, traceskip int)
-
-const (
- waitReasonSelect = 9 // Go: src/runtime/runtime2.go
- traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
-)
-
// Waiter allows a goroutine to block on pending events received by a Receiver.
//
// Waiter.Init() must be called before first use.
@@ -45,20 +29,19 @@ type Waiter struct {
// g is one of:
//
- // - nil: No goroutine is blocking in Wait.
+ // - 0: No goroutine is blocking in Wait.
//
- // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // - preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
// completed waiterUnlock(). Thus the wait can only be interrupted by
- // replacing the value of g with nil (the G may not be in state Gwaiting
- // yet, so we can't call goready.)
+ // replacing the value of g with 0 (the G may not be in state Gwaiting yet,
+ // so we can't call goready.)
//
// - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
// goroutine blocked in Wait, which can only be woken by calling goready.
- g unsafe.Pointer `state:"zerovalue"`
+ g uintptr `state:"zerovalue"`
}
-// Sentinel object for Waiter.g.
-var preparingG struct{}
+const preparingG = 1
// Init must be called before first use of w.
func (w *Waiter) Init() {
@@ -99,21 +82,29 @@ func (w *Waiter) WaitFor(es Set) Set {
}
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if p := w.r.Pending(); p&es != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
}
+//go:norace
+//go:nosplit
+func waiterCommit(g uintptr, wg unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(wg), preparingG, g)
+}
+
// Ack marks the given events as not pending.
func (w *Waiter) Ack(es Set) {
w.r.Ack(es)
@@ -135,20 +126,20 @@ func (w *Waiter) WaitAndAckAll() Set {
for {
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if w.r.Pending() != NoEvents {
if p := w.r.PendingAndAckAll(); p != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
// Check for pending events. We call PendingAndAckAll() directly now since
// we only expect to be woken after events become pending.
@@ -171,14 +162,14 @@ func (w *Waiter) NotifyPending() {
// goroutine. NotifyPending is called after w.r.Pending() is updated, so
// concurrent and future calls to w.Wait() will observe pending events and
// abort sleeping.
- if atomic.LoadPointer(&w.g) == nil {
+ if atomic.LoadUintptr(&w.g) == 0 {
return
}
// Wake a sleeping G, or prevent a G that is preparing to sleep from doing
// so. Swap is needed here to ensure that only one call to NotifyPending
// calls goready.
- if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
- goready(g, 0)
+ if g := atomic.SwapUintptr(&w.g, 0); g > preparingG {
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go
index fc6ef60a1..77faa3670 100644
--- a/pkg/syserr/host_linux.go
+++ b/pkg/syserr/host_linux.go
@@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation
// FromHost translates a syscall.Errno to a corresponding Error value.
func FromHost(err syscall.Errno) *Error {
- if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
+ if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err))
}
return linuxHostTranslations[err].err
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 5ae10939d..77c3c110c 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -15,6 +15,8 @@
package syserr
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,45 +50,56 @@ var (
ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
)
-var netstackErrorTranslations = map[*tcpip.Error]*Error{
- tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
- tcpip.ErrUnknownNICID: ErrUnknownNICID,
- tcpip.ErrUnknownDevice: ErrUnknownDevice,
- tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
- tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
- tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
- tcpip.ErrNoRoute: ErrNoRoute,
- tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
- tcpip.ErrAlreadyBound: ErrAlreadyBound,
- tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
- tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
- tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
- tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
- tcpip.ErrPortInUse: ErrPortInUse,
- tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
- tcpip.ErrClosedForSend: ErrClosedForSend,
- tcpip.ErrClosedForReceive: ErrClosedForReceive,
- tcpip.ErrWouldBlock: ErrWouldBlock,
- tcpip.ErrConnectionRefused: ErrConnectionRefused,
- tcpip.ErrTimeout: ErrTimeout,
- tcpip.ErrAborted: ErrAborted,
- tcpip.ErrConnectStarted: ErrConnectStarted,
- tcpip.ErrDestinationRequired: ErrDestinationRequired,
- tcpip.ErrNotSupported: ErrNotSupported,
- tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
- tcpip.ErrNotConnected: ErrNotConnected,
- tcpip.ErrConnectionReset: ErrConnectionReset,
- tcpip.ErrConnectionAborted: ErrConnectionAborted,
- tcpip.ErrNoSuchFile: ErrNoSuchFile,
- tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
- tcpip.ErrNoLinkAddress: ErrHostDown,
- tcpip.ErrBadAddress: ErrBadAddress,
- tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
- tcpip.ErrMessageTooLong: ErrMessageTooLong,
- tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
- tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
- tcpip.ErrNotPermitted: ErrNotPermittedNet,
- tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
+var netstackErrorTranslations map[string]*Error
+
+func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) {
+ key := tcpipErr.String()
+ if _, ok := netstackErrorTranslations[key]; ok {
+ panic(fmt.Sprintf("duplicate error key: %s", key))
+ }
+ netstackErrorTranslations[key] = netstackErr
+}
+
+func init() {
+ netstackErrorTranslations = make(map[string]*Error)
+ addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol)
+ addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID)
+ addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice)
+ addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption)
+ addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID)
+ addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress)
+ addErrMapping(tcpip.ErrNoRoute, ErrNoRoute)
+ addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint)
+ addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound)
+ addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState)
+ addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting)
+ addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected)
+ addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable)
+ addErrMapping(tcpip.ErrPortInUse, ErrPortInUse)
+ addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress)
+ addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend)
+ addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive)
+ addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock)
+ addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused)
+ addErrMapping(tcpip.ErrTimeout, ErrTimeout)
+ addErrMapping(tcpip.ErrAborted, ErrAborted)
+ addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted)
+ addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired)
+ addErrMapping(tcpip.ErrNotSupported, ErrNotSupported)
+ addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported)
+ addErrMapping(tcpip.ErrNotConnected, ErrNotConnected)
+ addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset)
+ addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted)
+ addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile)
+ addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue)
+ addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown)
+ addErrMapping(tcpip.ErrBadAddress, ErrBadAddress)
+ addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable)
+ addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong)
+ addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace)
+ addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
+ addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
+ addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
@@ -95,7 +108,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error {
if err == nil {
return nil
}
- se, ok := netstackErrorTranslations[err]
+ se, ok := netstackErrorTranslations[err.String()]
if !ok {
panic("Unknown error: " + err.String())
}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 81f762e10..91971b687 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"reflect"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker {
}
}
+// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
+// set in an IPv4 packet.
+func IPv4RouterAlert() NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ iterator := ip.Options().MakeIterator()
+ for {
+ opt, done, err := iterator.Next()
+ if err != nil {
+ t.Fatalf("error acquiring next IPv4 option %s", err)
+ }
+ if done {
+ break
+ }
+ if opt.Type() != header.IPv4OptionRouterAlertType {
+ continue
+ }
+ want := [header.IPv4OptionRouterAlertLength]byte{
+ byte(header.IPv4OptionRouterAlertType),
+ header.IPv4OptionRouterAlertLength,
+ header.IPv4OptionRouterAlertValue,
+ header.IPv4OptionRouterAlertValue,
+ }
+ if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
+ t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
+ }
+ return
+ }
+ t.Errorf("failed to find router alert option in %v", ip.Options())
+ }
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
+// field in ControlMessages.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasOriginalDstAddress {
+ t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+ } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
+ t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -1012,6 +1066,74 @@ func ICMPv6Payload(want []byte) TransportChecker {
}
}
+// MLD creates a checker that checks that the packet contains a valid MLD
+// message for type of mldType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// MLD message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.MessageBody()); got < minSize {
+ t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMaxRespDelay(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MaximumResponseDelay(); got != want {
+ t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// MLDMulticastAddress creates a checker that checks the Multicast Address
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMulticastAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MulticastAddress(); got != want {
+ t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
@@ -1031,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
last := h[len(h)-1]
icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
+ if got := len(icmp.MessageBody()); got < minSize {
t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
}
@@ -1065,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
if got := ns.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
@@ -1094,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
@@ -1112,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.SolicitedFlag(); got != want {
t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
@@ -1183,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
ndpOptions(t, na.Options(), opts)
}
}
@@ -1198,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ndpOptions(t, ns.Options(), opts)
}
}
@@ -1223,7 +1345,261 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ rs := header.NDPRouterSolicit(icmp.MessageBody())
ndpOptions(t, rs.Options(), opts)
}
}
+
+// IGMP checks the validity and properties of the given IGMP packet. It is
+// expected to be used in conjunction with other IGMP transport checkers for
+// specific properties.
+func IGMP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
+ }
+
+ igmp := header.IGMP(last.Payload())
+ for _, f := range checkers {
+ f(t, igmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// IGMPType creates a checker that checks the IGMP Type field.
+func IGMPType(want header.IGMPType) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.Type(); got != want {
+ t.Errorf("got igmp.Type() = %d, want = %d", got, want)
+ }
+ }
+}
+
+// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
+func IGMPMaxRespTime(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.MaxRespTime(); got != want {
+ t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
+func IGMPGroupAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.GroupAddress(); got != want {
+ t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IPv6ExtHdrChecker is a function to check an extension header.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("not a valid IPv6 packet")
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var rawPayloadHeader header.IPv6RawPayloadHeader
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
+ return
+ }
+ r, ok := h.(header.IPv6RawPayloadHeader)
+ if ok {
+ rawPayloadHeader = r
+ break
+ }
+ }
+
+ networkHeader := ipv6HeaderWithExtHdr{
+ IPv6: ipv6,
+ transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+ payload: rawPayloadHeader.Buf.ToView(),
+ }
+
+ for _, checker := range checkers {
+ checker(t, []header.Network{&networkHeader})
+ }
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+ if !ok {
+ t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+ buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+ )
+
+ for _, check := range headers {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ return
+ }
+ check(t, h)
+ }
+ // Validate we consumed all headers.
+ //
+ // The next one over should be a raw payload and then iterator should
+ // terminate.
+ wantDone := false
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done != wantDone {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+ return
+ }
+ if done {
+ break
+ }
+ if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+ t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+ continue
+ }
+ wantDone = true
+ }
+ }
+}
+
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+ header.IPv6
+ transport tcpip.TransportProtocolNumber
+ payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+ return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+ return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+ return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+ t.Helper()
+
+ hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+ if !ok {
+ t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+ return
+ }
+ optionsIterator := hbh.Iter()
+ for _, f := range checkers {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ f(t, opt)
+ }
+ // Validate all options were consumed.
+ for {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if !done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ if done {
+ break
+ }
+ }
+ }
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+ if !ok {
+ t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
+ return
+ }
+ if routerAlert.Value != want {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index d87797617..0bdc12d53 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -11,11 +11,13 @@ go_library(
"gue.go",
"icmpv4.go",
"icmpv6.go",
+ "igmp.go",
"interfaces.go",
"ipv4.go",
"ipv6.go",
"ipv6_extension_headers.go",
"ipv6_fragment.go",
+ "mld.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
@@ -39,6 +41,8 @@ go_test(
size = "small",
srcs = [
"checksum_test.go",
+ "igmp_test.go",
+ "ipv4_test.go",
"ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
@@ -58,6 +62,7 @@ go_test(
srcs = [
"eth_test.go",
"ipv6_extension_headers_test.go",
+ "mld_test.go",
"ndp_test.go",
],
library = ":header",
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4303fc5d5..2eef64b4d 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -115,6 +115,12 @@ const (
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
+
+ // Multicast Listener Discovery (MLD) messages, see RFC 2710.
+
+ ICMPv6MulticastListenerQuery ICMPv6Type = 130
+ ICMPv6MulticastListenerReport ICMPv6Type = 131
+ ICMPv6MulticastListenerDone ICMPv6Type = 132
)
// IsErrorType returns true if the receiver is an ICMP error type.
@@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
-// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
-// packet's message body as defined by RFC 4443 section 2.1; the portion of the
-// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
-func (b ICMPv6) NDPPayload() []byte {
+// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
+// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go
new file mode 100644
index 000000000..5c5be1b9d
--- /dev/null
+++ b/pkg/tcpip/header/igmp.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// IGMP represents an IGMP header stored in a byte array.
+type IGMP []byte
+
+// IGMP implements `Transport`.
+var _ Transport = (*IGMP)(nil)
+
+const (
+ // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes,
+ // as per RFC 2236, Section 2, Page 2.
+ IGMPMinimumSize = 8
+
+ // IGMPQueryMinimumSize is the minimum size of a valid Membership Query
+ // Message in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPQueryMinimumSize = 8
+
+ // IGMPReportMinimumSize is the minimum size of a valid Report Message in
+ // bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPReportMinimumSize = 8
+
+ // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message
+ // in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPLeaveMessageMinimumSize = 8
+
+ // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page
+ // 3.
+ IGMPTTL = 1
+
+ // igmpTypeOffset defines the offset of the type field in an IGMP message.
+ igmpTypeOffset = 0
+
+ // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an
+ // IGMP message.
+ igmpMaxRespTimeOffset = 1
+
+ // igmpChecksumOffset defines the offset of the checksum field in an IGMP
+ // message.
+ igmpChecksumOffset = 2
+
+ // igmpGroupAddressOffset defines the offset of the Group Address field in an
+ // IGMP message.
+ igmpGroupAddressOffset = 4
+
+ // IGMPProtocolNumber is IGMP's transport protocol number.
+ IGMPProtocolNumber tcpip.TransportProtocolNumber = 2
+)
+
+// IGMPType is the IGMP type field as per RFC 2236.
+type IGMPType byte
+
+// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2.
+// Descriptions below come from there.
+const (
+ // IGMPMembershipQuery indicates that the message type is Membership Query.
+ // "There are two sub-types of Membership Query messages:
+ // - General Query, used to learn which groups have members on an
+ // attached network.
+ // - Group-Specific Query, used to learn if a particular group
+ // has any members on an attached network.
+ // These two messages are differentiated by the Group Address, as
+ // described in section 1.4 ."
+ IGMPMembershipQuery IGMPType = 0x11
+ // IGMPv1MembershipReport indicates that the message is a Membership Report
+ // generated by a host using the IGMPv1 protocol: "an additional type of
+ // message, for backwards-compatibility with IGMPv1"
+ IGMPv1MembershipReport IGMPType = 0x12
+ // IGMPv2MembershipReport indicates that the Message type is a Membership
+ // Report generated by a host using the IGMPv2 protocol.
+ IGMPv2MembershipReport IGMPType = 0x16
+ // IGMPLeaveGroup indicates that the message type is a Leave Group
+ // notification message.
+ IGMPLeaveGroup IGMPType = 0x17
+)
+
+// Type is the IGMP type field.
+func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) }
+
+// SetType sets the IGMP type field.
+func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) }
+
+// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership
+// Query messages, in other cases it is set to 0 by the sender and ignored by
+// the receiver.
+func (b IGMP) MaxRespTime() time.Duration {
+ // As per RFC 2236 section 2.2,
+ //
+ // The Max Response Time field is meaningful only in Membership Query
+ // messages, and specifies the maximum allowed time before sending a
+ // responding report in units of 1/10 second. In all other messages, it
+ // is set to zero by the sender and ignored by receivers.
+ return DecisecondToDuration(b[igmpMaxRespTimeOffset])
+}
+
+// SetMaxRespTime sets the MaxRespTimeField.
+func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m }
+
+// Checksum is the IGMP checksum field.
+func (b IGMP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[igmpChecksumOffset:])
+}
+
+// SetChecksum sets the IGMP checksum field.
+func (b IGMP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum)
+}
+
+// GroupAddress gets the Group Address field.
+func (b IGMP) GroupAddress() tcpip.Address {
+ return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize])
+}
+
+// SetGroupAddress sets the Group Address field.
+func (b IGMP) SetGroupAddress(address tcpip.Address) {
+ if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize))
+ }
+}
+
+// SourcePort implements Transport.SourcePort.
+func (IGMP) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (IGMP) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (IGMP) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (IGMP) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (IGMP) Payload() []byte {
+ return nil
+}
+
+// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP
+// header.
+func IGMPCalculateChecksum(h IGMP) uint16 {
+ // The header contains a checksum itself, set it aside to avoid checksumming
+ // the checksum and replace it afterwards.
+ existingXsum := h.Checksum()
+ h.SetChecksum(0)
+ xsum := ^Checksum(h, 0)
+ h.SetChecksum(existingXsum)
+ return xsum
+}
+
+// DecisecondToDuration converts a value representing deci-seconds to a
+// time.Duration.
+func DecisecondToDuration(ds uint8) time.Duration {
+ return time.Duration(ds) * time.Second / 10
+}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
new file mode 100644
index 000000000..b6126d29a
--- /dev/null
+++ b/pkg/tcpip/header/igmp_test.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestIGMPHeader tests the functions within header.igmp
+func TestIGMPHeader(t *testing.T) {
+ const maxRespTimeTenthSec = 0xF0
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ maxRespTimeTenthSec, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
+ }
+
+ igmpType := header.IGMPv2MembershipReport
+ igmpHeader.SetType(igmpType)
+ if got := igmpHeader.Type(); got != igmpType {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType)
+ }
+ if got := header.IGMPType(b[0]); got != igmpType {
+ t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType)
+ }
+
+ respTime := byte(0x02)
+ igmpHeader.SetMaxRespTime(respTime)
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ checksum := uint16(0x0102)
+ igmpHeader.SetChecksum(checksum)
+ if got := igmpHeader.Checksum(); got != checksum {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
+ }
+
+ groupAddress := tcpip.Address("\x04\x03\x02\x01")
+ igmpHeader.SetGroupAddress(groupAddress)
+ if got := igmpHeader.GroupAddress(); got != groupAddress {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
+ }
+}
+
+// TestIGMPChecksum ensures that the checksum calculator produces the expected
+// checksum.
+func TestIGMPChecksum(t *testing.T) {
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ 0xF0, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ // Calculate the initial checksum after setting the checksum temporarily to 0
+ // to avoid checksumming the checksum.
+ initialChecksum := igmpHeader.Checksum()
+ igmpHeader.SetChecksum(0)
+ checksum := ^header.Checksum(b, 0)
+ igmpHeader.SetChecksum(initialChecksum)
+
+ if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum {
+ t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum)
+ }
+}
+
+func TestDecisecondToDuration(t *testing.T) {
+ const valueInDeciseconds = 5
+ if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want {
+ t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 91fe7b6a5..e6103f4bc 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -100,7 +100,7 @@ type IPv4Fields struct {
//
// That leaves ten 32 bit (4 byte) fields for options. An attempt to encode
// more will fail.
- Options IPv4Options
+ Options IPv4OptionsSerializer
}
// IPv4 is an IPv4 header.
@@ -157,6 +157,9 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+ // IPv4AllRoutersGroup is a multicast address for all routers.
+ IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02"
+
// IPv4MinimumProcessableDatagramSize is the minimum size of an IP
// packet that every IPv4 capable host must be able to
// process/reassemble.
@@ -282,18 +285,17 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// IPv4Options is a buffer that holds all the raw IP options.
-type IPv4Options []byte
-
-// SizeWithPadding implements stack.NetOptions.
-// It reports the size to allocate for the Options. RFC 791 page 23 (end of
-// section 3.1) says of the padding at the end of the options:
+// padIPv4OptionsLength returns the total length for IPv4 options of length l
+// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
// header ends on a 32 bit boundary.
-func (o IPv4Options) SizeWithPadding() int {
- return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1)
+func padIPv4OptionsLength(length uint8) uint8 {
+ return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1)
}
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
// Options returns a buffer holding the options.
func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
@@ -372,26 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 {
func (b IPv4) Encode(i *IPv4Fields) {
// The size of the options defines the size of the whole header and thus the
// IHL field. Options are rare and this is a heavily used function so it is
- // worth a bit of optimisation here to keep the copy out of the fast path.
- hdrLen := IPv4MinimumSize
+ // worth a bit of optimisation here to keep the serializer out of the fast
+ // path.
+ hdrLen := uint8(IPv4MinimumSize)
if len(i.Options) != 0 {
- // SizeWithPadding is always >= len(i.Options).
- aLen := i.Options.SizeWithPadding()
- hdrLen += aLen
- if hdrLen > len(b) {
- panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen))
- }
- opts := b[options:]
- // This avoids bounds checks on the next line(s) which would happen even
- // if there's no work to do.
- if n := copy(opts, i.Options); n != aLen {
- padding := opts[n:][:aLen-n]
- for i := range padding {
- padding[i] = 0
- }
- }
+ hdrLen += i.Options.Serialize(b[options:])
+ }
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize))
}
- b.SetHeaderLength(uint8(hdrLen))
+ b.SetHeaderLength(hdrLen)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
@@ -471,6 +463,10 @@ const (
// options and may appear multiple times.
IPv4OptionNOPType IPv4OptionType = 1
+ // IPv4OptionRouterAlertType is the option type for the Router Alert option,
+ // defined in RFC 2113 Section 2.1.
+ IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80
+
// IPv4OptionRecordRouteType is used by each router on the path of the packet
// to record its path. It is carried over to an Echo Reply.
IPv4OptionRecordRouteType IPv4OptionType = 7
@@ -871,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+
+// Router Alert option specific related constants.
+//
+// from RFC 2113 section 2.1:
+//
+// +--------+--------+--------+--------+
+// |10010100|00000100| 2 octet value |
+// +--------+--------+--------+--------+
+//
+// Type:
+// Copied flag: 1 (all fragments must carry the option)
+// Option class: 0 (control)
+// Option number: 20 (decimal)
+//
+// Length: 4
+//
+// Value: A two octet code with the following values:
+// 0 - Router shall examine packet
+// 1-65535 - Reserved
+const (
+ // IPv4OptionRouterAlertLength is the length of a Router Alert option.
+ IPv4OptionRouterAlertLength = 4
+
+ // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit
+ // payload of the router alert option.
+ IPv4OptionRouterAlertValue = 0
+
+ // iPv4OptionRouterAlertValueOffset is the offset for the value of a
+ // RouterAlert option.
+ iPv4OptionRouterAlertValueOffset = 2
+)
+
+// IPv4SerializableOption is an interface to represent serializable IPv4 option
+// types.
+type IPv4SerializableOption interface {
+ // optionType returns the type identifier of the option.
+ optionType() IPv4OptionType
+}
+
+// IPv4SerializableOptionPayload is an interface providing serialization of the
+// payload of an IPv4 option.
+type IPv4SerializableOptionPayload interface {
+ // length returns the size of the payload.
+ length() uint8
+
+ // serializeInto serializes the payload into the provided byte buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MUST panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto(buffer []byte) uint8
+}
+
+// IPv4OptionsSerializer is a serializer for IPv4 options.
+type IPv4OptionsSerializer []IPv4SerializableOption
+
+// Length returns the total number of bytes required to serialize the options.
+func (s IPv4OptionsSerializer) Length() uint8 {
+ var total uint8
+ for _, opt := range s {
+ total++
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Add 1 to reported length to account for the length byte.
+ total += 1 + withPayload.length()
+ }
+ }
+ return padIPv4OptionsLength(total)
+}
+
+// Serialize serializes the provided list of IPV4 options into b.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// IPv4OptionsSerializer.Length for details on the getting the total size
+// of a serialized IPv4OptionsSerializer.
+//
+// Serialize panics if b is not of sufficient size to hold all the options in s.
+func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 {
+ var total uint8
+ for _, opt := range s {
+ ty := opt.optionType()
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Serialize first to reduce bounds checks.
+ l := 2 + withPayload.serializeInto(b[2:])
+ b[0] = byte(ty)
+ b[1] = l
+ b = b[l:]
+ total += l
+ continue
+ }
+ // Options without payload consist only of the type field.
+ //
+ // NB: Repeating code from the branch above is intentional to minimize
+ // bounds checks.
+ b[0] = byte(ty)
+ b = b[1:]
+ total++
+ }
+
+ // According to RFC 791:
+ //
+ // The internet header padding is used to ensure that the internet
+ // header ends on a 32 bit boundary. The padding is zero.
+ padded := padIPv4OptionsLength(total)
+ b = b[:padded-total]
+ for i := range b {
+ b[i] = 0
+ }
+ return padded
+}
+
+var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil)
+var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil)
+
+// IPv4SerializableRouterAlertOption provides serialization of the Router Alert
+// IPv4 option according to RFC 2113.
+type IPv4SerializableRouterAlertOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType {
+ return IPv4OptionRouterAlertType
+}
+
+// Length implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) length() uint8 {
+ return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset
+}
+
+// SerializeInto implements IPv4SerializableOption.
+func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 {
+ binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue)
+ return o.length()
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil)
+
+// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option.
+type IPv4SerializableNOPOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableNOPOption) optionType() IPv4OptionType {
+ return IPv4OptionNOPType
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil)
+
+// IPv4SerializableListEndOption provides serialization for the IPv4 List End
+// option.
+type IPv4SerializableListEndOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableListEndOption) optionType() IPv4OptionType {
+ return IPv4OptionListEndType
+}
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
new file mode 100644
index 000000000..6475cd694
--- /dev/null
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -0,0 +1,179 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestIPv4OptionsSerializer(t *testing.T) {
+ optCases := []struct {
+ name string
+ option []header.IPv4SerializableOption
+ expect []byte
+ }{
+ {
+ name: "NOP",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ },
+ expect: []byte{1, 0, 0, 0},
+ },
+ {
+ name: "ListEnd",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableListEndOption{},
+ },
+ expect: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{148, 4, 0, 0},
+ }, {
+ name: "NOP and RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{1, 148, 4, 0, 0, 0, 0, 0},
+ },
+ }
+
+ for _, opt := range optCases {
+ t.Run(opt.name, func(t *testing.T) {
+ s := header.IPv4OptionsSerializer(opt.option)
+ l := s.Length()
+ if got := len(opt.expect); got != int(l) {
+ t.Fatalf("s.Length() = %d, want = %d", got, l)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with full bytes to ensure padding is being set
+ // correctly.
+ b[i] = 0xFF
+ }
+ if serializedLength := s.Serialize(b); serializedLength != l {
+ t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l)
+ }
+ if diff := cmp.Diff(opt.expect, b); diff != "" {
+ t.Errorf("mismatched serialized option (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
+// fields when options are supplied.
+func TestIPv4EncodeOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ numberOfNops int
+ encodedOptions header.IPv4Options // reply should look like this
+ wantIHL int
+ }{
+ {
+ name: "valid no options",
+ wantIHL: header.IPv4MinimumSize,
+ },
+ {
+ name: "one byte options",
+ numberOfNops: 1,
+ encodedOptions: header.IPv4Options{1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "two byte options",
+ numberOfNops: 2,
+ encodedOptions: header.IPv4Options{1, 1, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "three byte options",
+ numberOfNops: 3,
+ encodedOptions: header.IPv4Options{1, 1, 1, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "four byte options",
+ numberOfNops: 4,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "five byte options",
+ numberOfNops: 5,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 8,
+ },
+ {
+ name: "thirty nine byte options",
+ numberOfNops: 39,
+ encodedOptions: header.IPv4Options{
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 0,
+ },
+ wantIHL: header.IPv4MinimumSize + 40,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops))
+ for i := range serializeOpts {
+ serializeOpts[i] = &header.IPv4SerializableNOPOption{}
+ }
+ paddedOptionLength := serializeOpts.Length()
+ ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength)
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength)
+ hdr := buffer.NewPrependable(int(totalLen))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ // To check the padding works, poison the last byte of the options space.
+ if paddedOptionLength != serializeOpts.Length() {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0xff
+ ip.SetHeaderLength(0)
+ }
+ ip.Encode(&header.IPv4Fields{
+ Options: serializeOpts,
+ })
+ options := ip.Options()
+ wantOptions := test.encodedOptions
+ if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+ t.Errorf("got IHL of %d, want %d", got, want)
+ }
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(wantOptions) == 0 && len(options) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(wantOptions, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 55d09355a..d522e5f10 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -48,11 +48,13 @@ type IPv6Fields struct {
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
- // PayloadLength is the "payload length" field of an IPv6 packet.
+ // PayloadLength is the "payload length" field of an IPv6 packet, including
+ // the length of all extension headers.
PayloadLength uint16
- // NextHeader is the "next header" field of an IPv6 packet.
- NextHeader uint8
+ // TransportProtocol is the transport layer protocol number. Serialized in the
+ // last "next header" field of the IPv6 header + extension headers.
+ TransportProtocol tcpip.TransportProtocolNumber
// HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
@@ -62,6 +64,9 @@ type IPv6Fields struct {
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
+
+ // ExtensionHeaders are the extension headers following the IPv6 header.
+ ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
@@ -253,12 +258,14 @@ func (IPv6) SetChecksum(uint16) {
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
+ extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
+ nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
+ b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 583c2c5d3..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
+ "math"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -47,6 +50,11 @@ const (
// IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
// of an IPv6 payload, as per RFC 8200 section 4.7.
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+
+ // IPv6UnknownExtHdrIdentifier is reserved by IANA.
+ // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
+ // "254 Use for experimentation and testing [RFC3692][RFC4727]"
+ IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254
)
const (
@@ -70,8 +78,8 @@ const (
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
- // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
- // discard from the Fragment Offset.
+ // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
+ // Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
@@ -109,6 +117,37 @@ const (
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
+// padIPv6OptionsLength returns the total length for IPv6 options of length l
+// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
+func padIPv6OptionsLength(length int) int {
+ return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
+}
+
+// padIPv6Option fills b with the appropriate padding options depending on its
+// length.
+func padIPv6Option(b []byte) {
+ switch len(b) {
+ case 0: // No padding needed.
+ case 1: // Pad with Pad1.
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
+ default: // Pad with PadN.
+ s := b[ipv6ExtHdrOptionPayloadOffset:]
+ for i := range s {
+ s[i] = 0
+ }
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
+ b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
+ }
+}
+
+// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
+// serialize an option at headerOffset with alignment requirements
+// [align]n + alignOffset.
+func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
+ padLen := headerOffset - alignOffset
+ return ((padLen + align - 1) & ^(align - 1)) - padLen
+}
+
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
@@ -201,29 +240,55 @@ type IPv6ExtHdrOption interface {
isIPv6ExtHdrOption()
}
-// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
-type IPv6ExtHdrOptionIndentifier uint8
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
- ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
- ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+
+ // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
+ // Alert Hop by Hop option as defined in RFC 2711 section 2.1.
+ ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
+
+ // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
+ // option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionTypeOffset = 0
+
+ // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionLengthOffset = 1
+
+ // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionPayloadOffset = 2
)
+// ipv6UnknownActionFromIdentifier maps an extension header option's
+// identifier's high bits to the action to take when the identifier is unknown.
+func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
- Identifier IPv6ExtHdrOptionIndentifier
+ Identifier IPv6ExtHdrOptionIdentifier
Data []byte
}
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
- return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+ return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
@@ -246,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
- id := IPv6ExtHdrOptionIndentifier(temp)
+ id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
@@ -289,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
continue
+ case ipv6RouterAlertHopByHopOptionIdentifier:
+ var routerAlertValue [ipv6RouterAlertPayloadLength]byte
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ }
+ return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := make([]byte, length)
if n, err := io.ReadFull(&i.reader, bytes); err != nil {
@@ -452,9 +530,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Since we consume the iterator, we return the payload as is.
buf = i.payload
- // Mark i as done.
+ // Mark i as done, but keep track of where we were for error reporting.
*i = IPv6PayloadIterator{
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ headerOffset: i.headerOffset,
+ nextOffset: i.nextOffset,
}
} else {
buf = i.payload.Clone(nil)
@@ -602,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
}
+
+// IPv6SerializableExtHdr provides serialization for IPv6 extension
+// headers.
+type IPv6SerializableExtHdr interface {
+ // identifier returns the assigned IPv6 header identifier for this extension
+ // header.
+ identifier() IPv6ExtensionHeaderIdentifier
+
+ // length returns the total serialized length in bytes of this extension
+ // header, including the common next header and length fields.
+ length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer and with the provided nextHeader value.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto returns the number of bytes that was used to serialize the
+ // receiver. Implementers must only use the number of bytes required to
+ // serialize the receiver. Callers MAY provide a larger buffer than required
+ // to serialize into.
+ serializeInto(nextHeader uint8, b []byte) int
+}
+
+var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
+
+// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
+// options extension header.
+type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
+
+const (
+ // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
+ // in a hop by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrNextHeaderOffset = 0
+
+ // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
+ // by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrLengthOffset = 1
+
+ // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
+ // hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrOptionsOffset = 2
+
+ // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
+ // words in a hop by hop extension header's length field, as stated in RFC
+ // 8200 section 4.3:
+ // Length of the Hop-by-Hop Options header in 8-octet units,
+ // not including the first 8 octets.
+ ipv6HopByHopExtHdrUnaccountedLenWords = 1
+)
+
+// identifier implements IPv6SerializableExtHdr.
+func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6HopByHopOptionsExtHdrIdentifier
+}
+
+// length implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) length() int {
+ var total int
+ for _, opt := range h {
+ align, alignOffset := opt.alignment()
+ total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
+ total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
+ }
+ // Account for next header and total length fields and add padding.
+ return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
+ totalLength := ipv6HopByHopExtHdrOptionsOffset
+ for _, opt := range h {
+ // Calculate alignment requirements and pad buffer if necessary.
+ align, alignOffset := opt.alignment()
+ padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
+ if padLen != 0 {
+ padIPv6Option(optBuffer[:padLen])
+ totalLength += padLen
+ optBuffer = optBuffer[padLen:]
+ }
+
+ l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
+ optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
+ optBuffer[ipv6ExtHdrOptionLengthOffset] = l
+ l += ipv6ExtHdrOptionPayloadOffset
+ totalLength += int(l)
+ optBuffer = optBuffer[l:]
+ }
+ padded := padIPv6OptionsLength(totalLength)
+ if padded != totalLength {
+ padIPv6Option(optBuffer[:padded-totalLength])
+ totalLength = padded
+ }
+ wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
+ if wordsLen > math.MaxUint8 {
+ panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
+ }
+ b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
+ b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
+ return totalLength
+}
+
+// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
+type IPv6SerializableHopByHopOption interface {
+ // identifier returns the option identifier of this Hop by Hop option.
+ identifier() IPv6ExtHdrOptionIdentifier
+
+ // length returns the *payload* size of the option (not considering the type
+ // and length fields).
+ length() uint8
+
+ // alignment returns the alignment requirements from this option.
+ //
+ // Alignment requirements take the form [align]n + offset as specified in
+ // RFC 8200 section 4.2. The alignment requirement is on the offset between
+ // the option type byte and the start of the hop by hop header.
+ //
+ // align must be a power of 2.
+ alignment() (align int, offset int)
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) uint8
+}
+
+var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
+
+// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
+// RFC 2711 section 2.1.
+type IPv6RouterAlertOption struct {
+ Value IPv6RouterAlertValue
+}
+
+// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
+type IPv6RouterAlertValue uint16
+
+const (
+ // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
+ // Discovery message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertMLD IPv6RouterAlertValue = 0
+ // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
+ // defined in RFC 2711 section 2.1.
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
+ // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
+ // Networks message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
+
+ // ipv6RouterAlertPayloadLength is the length of the Router Alert payload
+ // as defined in RFC 2711.
+ ipv6RouterAlertPayloadLength = 2
+
+ // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
+ // Router Alert option defined as 2n+0 in RFC 2711.
+ ipv6RouterAlertAlignmentRequirement = 2
+
+ // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
+ // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
+ // 2.1.
+ ipv6RouterAlertAlignmentOffsetRequirement = 0
+)
+
+// UnknownAction implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
+ return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
+ return ipv6RouterAlertHopByHopOptionIdentifier
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) length() uint8 {
+ return ipv6RouterAlertPayloadLength
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) alignment() (int, int) {
+ // From RFC 2711 section 2.1:
+ // Alignment requirement: 2n+0.
+ return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
+ binary.BigEndian.PutUint16(b, uint16(o.Value))
+ return ipv6RouterAlertPayloadLength
+}
+
+// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
+
+// Serialize serializes the provided list of IPv6 extension headers into b.
+//
+// Note, b must be of sufficient size to hold all the headers in s. See
+// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
+// serialized IPv6ExtHdrSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+//
+// Serialize takes the transportProtocol value to be used as the last extension
+// header's Next Header value and returns the header identifier of the first
+// serialized extension header and the total serialized length.
+func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
+ nextHeader := uint8(transportProtocol)
+ if len(s) == 0 {
+ return nextHeader, 0
+ }
+ var totalLength int
+ for i, h := range s[:len(s)-1] {
+ length := h.serializeInto(uint8(s[i+1].identifier()), b)
+ b = b[length:]
+ totalLength += length
+ }
+ totalLength += s[len(s)-1].serializeInto(nextHeader, b)
+ return uint8(s[0].identifier()), totalLength
+}
+
+// Length returns the total number of bytes required to serialize the extension
+// headers.
+func (s IPv6ExtHdrSerializer) Length() int {
+ var totalLength int
+ for _, h := range s {
+ totalLength += h.length()
+ }
+ return totalLength
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index ab20c5f37..65adc6250 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool
func TestIPv6UnknownExtHdrOption(t *testing.T) {
tests := []struct {
name string
- identifier IPv6ExtHdrOptionIndentifier
+ identifier IPv6ExtHdrOptionIdentifier
expectedUnknownAction IPv6OptionUnknownAction
}{
{
@@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
bytes: []byte{1, 3},
err: io.ErrUnexpectedEOF,
},
+ {
+ name: "Router alert without data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data and Pad1",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with extra data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with missing data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
+ err: io.ErrUnexpectedEOF,
+ },
}
check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
@@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) {
})
}
}
+
+var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
+
+// dummyHbHOptionSerializer provides a generic implementation of
+// IPv6SerializableHopByHopOption for use in tests.
+type dummyHbHOptionSerializer struct {
+ id IPv6ExtHdrOptionIdentifier
+ payload []byte
+ align int
+ alignOffset int
+}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) length() uint8 {
+ return uint8(len(s.payload))
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) alignment() (int, int) {
+ align := 1
+ if s.align != 0 {
+ align = s.align
+ }
+ return align, s.alignOffset
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
+ return uint8(copy(b, s.payload))
+}
+
+func TestIPv6HopByHopSerializer(t *testing.T) {
+ validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ dummy, ok := serializable.(*dummyHbHOptionSerializer)
+ if !ok {
+ t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
+ }
+ unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
+ }
+ if dummy.id != unknown.Identifier {
+ t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
+ }
+ if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
+ t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
+ }
+ }
+ tests := []struct {
+ name string
+ nextHeader uint8
+ options []IPv6SerializableHopByHopOption
+ expect []byte
+ validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
+ }{
+ {
+ name: "single option",
+ nextHeader: 13,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 15,
+ payload: []byte{9, 8, 7, 6},
+ },
+ },
+ expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
+ validate: validateDummies,
+ },
+ {
+ name: "short option padN zero",
+ nextHeader: 88,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5},
+ },
+ },
+ expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "short option pad1",
+ nextHeader: 11,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 33,
+ payload: []byte{1, 2, 3},
+ },
+ },
+ expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "long option padN",
+ nextHeader: 55,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 77,
+ payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ },
+ expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 2n",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 2,
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 8n+1",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 8,
+ alignOffset: 1,
+ },
+ },
+ expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "no options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{},
+ expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
+ },
+ {
+ name: "Router Alert",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
+ expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
+ validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
+ }
+ if routerAlert.Value != IPv6RouterAlertMLD {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6SerializableHopByHopExtHdr(test.options)
+ length := s.length()
+ if length != len(test.expect) {
+ t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
+ }
+ b := make([]byte, length)
+ for i := range b {
+ // Fill the buffer with ones to ensure all padding is correctly set.
+ b[i] = 0xFF
+ }
+ if got := s.serializeInto(test.nextHeader, b); got != length {
+ t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
+ }
+ if diff := cmp.Diff(test.expect, b); diff != "" {
+ t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
+ }
+
+ // Deserialize the options and verify them.
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
+ iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
+ for _, testOpt := range test.options {
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ test.validate(t, testOpt, opt)
+ }
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if !done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ })
+ }
+}
+
+var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
+
+// dummyIPv6ExtHdrSerializer provides a generic implementation of
+// IPv6SerializableExtHdr for use in tests.
+//
+// The dummy header always carries the nextHeader value in the first byte.
+type dummyIPv6ExtHdrSerializer struct {
+ id IPv6ExtensionHeaderIdentifier
+ headerContents []byte
+}
+
+// identifier implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) length() int {
+ return len(s.headerContents) + 1
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
+ b[0] = nextHeader
+ return copy(b[1:], s.headerContents) + 1
+}
+
+func TestIPv6ExtHdrSerializer(t *testing.T) {
+ tests := []struct {
+ name string
+ headers []IPv6SerializableExtHdr
+ nextHeader tcpip.TransportProtocolNumber
+ expectSerialized []byte
+ expectNextHeader uint8
+ }{
+ {
+ name: "one header",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 15,
+ headerContents: []byte{1, 2, 3, 4},
+ },
+ },
+ nextHeader: TCPProtocolNumber,
+ expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
+ expectNextHeader: 15,
+ },
+ {
+ name: "two headers",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 22,
+ headerContents: []byte{1, 2, 3},
+ },
+ &dummyIPv6ExtHdrSerializer{
+ id: 23,
+ headerContents: []byte{4, 5, 6},
+ },
+ },
+ nextHeader: ICMPv6ProtocolNumber,
+ expectSerialized: []byte{
+ 23, 1, 2, 3,
+ byte(ICMPv6ProtocolNumber), 4, 5, 6,
+ },
+ expectNextHeader: 22,
+ },
+ {
+ name: "no headers",
+ headers: []IPv6SerializableExtHdr{},
+ nextHeader: UDPProtocolNumber,
+ expectSerialized: []byte{},
+ expectNextHeader: byte(UDPProtocolNumber),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6ExtHdrSerializer(test.headers)
+ l := s.Length()
+ if got, want := l, len(test.expectSerialized); got != want {
+ t.Fatalf("got serialized length = %d, want = %d", got, want)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with garbage to make sure we're writing to all bytes.
+ b[i] = 0xFF
+ }
+ nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
+ if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
+ t.Errorf(
+ "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
+ nextHeader,
+ serializedLen,
+ test.expectNextHeader,
+ len(test.expectSerialized),
+ )
+ }
+ if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
+ t.Errorf("serialization mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
index 018555a26..9d09f32eb 100644
--- a/pkg/tcpip/header/ipv6_fragment.go
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -27,12 +27,11 @@ const (
idV6 = 4
)
-// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
-// fields of a packet that needs to be encoded.
-type IPv6FragmentFields struct {
- // NextHeader is the "next header" field of an IPv6 fragment.
- NextHeader uint8
+var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
+// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
+// extension header as defined in RFC 8200 section 4.5.
+type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
@@ -43,6 +42,29 @@ type IPv6FragmentFields struct {
Identification uint32
}
+// identifier implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6FragmentHeader
+}
+
+// length implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) length() int {
+ return IPv6FragmentHeaderSize
+}
+
+// serializeInto implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ // Prevent too many bounds checks.
+ _ = b[IPv6FragmentHeaderSize:]
+ binary.BigEndian.PutUint32(b[idV6:], h.Identification)
+ binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
+ b[nextHdrFrag] = nextHeader
+ if h.M {
+ b[more] |= ipv6FragmentExtHdrMFlagMask
+ }
+ return IPv6FragmentHeaderSize
+}
+
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
@@ -58,16 +80,6 @@ const (
IPv6FragmentHeaderSize = 8
)
-// Encode encodes all the fields of the ipv6 fragment.
-func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
- b[nextHdrFrag] = i.NextHeader
- binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
- if i.M {
- b[more] |= 1
- }
- binary.BigEndian.PutUint32(b[idV6:], i.Identification)
-}
-
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go
new file mode 100644
index 000000000..ffe03c76a
--- /dev/null
+++ b/pkg/tcpip/header/mld.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // MLDMinimumSize is the minimum size for an MLD message.
+ MLDMinimumSize = 20
+
+ // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as
+ // per RFC 2710 section 3.
+ MLDHopLimit = 1
+
+ // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay
+ // field within MLD.
+ mldMaximumResponseDelayOffset = 0
+
+ // mldMulticastAddressOffset is the offset to the Multicast Address field
+ // within MLD.
+ mldMulticastAddressOffset = 4
+)
+
+// MLD is a Multicast Listener Discovery message in an ICMPv6 packet.
+//
+// MLD will only contain the body of an ICMPv6 packet.
+//
+// As per RFC 2710 section 3, MLD messages have the following format (MLD only
+// holds the bytes after the first four bytes in the diagram below):
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Maximum Response Delay | Reserved |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | |
+// + +
+// | |
+// + Multicast Address +
+// | |
+// + +
+// | |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+type MLD []byte
+
+// MaximumResponseDelay returns the Maximum Response Delay.
+func (m MLD) MaximumResponseDelay() time.Duration {
+ // As per RFC 2710 section 3.4:
+ //
+ // The Maximum Response Delay field is meaningful only in Query
+ // messages, and specifies the maximum allowed delay before sending a
+ // responding Report, in units of milliseconds. In all other messages,
+ // it is set to zero by the sender and ignored by receivers.
+ return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond
+}
+
+// SetMaximumResponseDelay sets the Maximum Response Delay field.
+//
+// maxRespDelayMS is the value in milliseconds.
+func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) {
+ binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS)
+}
+
+// MulticastAddress returns the Multicast Address.
+func (m MLD) MulticastAddress() tcpip.Address {
+ // As per RFC 2710 section 3.5:
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ //
+ // In a Report or Done message, the Multicast Address field holds a
+ // specific IPv6 multicast address to which the message sender is
+ // listening or is ceasing to listen, respectively.
+ return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize])
+}
+
+// SetMulticastAddress sets the Multicast Address field.
+func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) {
+ if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize))
+ }
+}
diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go
new file mode 100644
index 000000000..0cecf10d4
--- /dev/null
+++ b/pkg/tcpip/header/mld_test.go
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestMLD(t *testing.T) {
+ b := []byte{
+ // Maximum Response Delay
+ 0, 0,
+
+ // Reserved
+ 0, 0,
+
+ // MulticastAddress
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
+ }
+
+ const maxRespDelay = 513
+ binary.BigEndian.PutUint16(b, maxRespDelay)
+
+ mld := MLD(b)
+
+ if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ const newMaxRespDelay = 1234
+ mld.SetMaximumResponseDelay(newMaxRespDelay)
+ if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want)
+ }
+
+ multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})
+ mld.SetMulticastAddress(multicastAddress)
+ if got := mld.MulticastAddress(); got != multicastAddress {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress)
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 5d3975c56..554242f0c 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
return it, nil
}
-// Serialize serializes the provided list of NDP options into o.
+// Serialize serializes the provided list of NDP options into b.
//
// Note, b must be of sufficient size to hold all the options in s. See
// NDPOptionsSerializer.Length for details on the getting the total size
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 98bdd29db..a6d4fcd59 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
- // SrcPort is the "source port" field of a UDP packet.
+ // SrcPort is the "Source Port" field of a UDP packet.
SrcPort uint16
- // DstPort is the "destination port" field of a UDP packet.
+ // DstPort is the "Destination Port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
@@ -64,52 +64,57 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
+// SourcePort returns the "Source Port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "destination port" field of the udp header.
+// DestinationPort returns the "Destination Port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "length" field of the udp header.
+// Length returns the "Length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
- return b[UDPMinimumSize:]
+ return b[:b.Length()][UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the udp header.
+// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the udp header.
+// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the udp header.
+// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the udp header.
+// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the udp header.
+// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the
+// PayloadLength returns the length of the payload following the UDP header.
+func (b UDP) PayloadLength() uint16 {
+ return b.Length() - UDPMinimumSize
+}
+
+// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 39ca774ef..973f06cbc 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -9,7 +9,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c95aef63c..0efbfb22b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -32,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route stack.Route
+ Route *stack.Route
}
// Notification is the interface for receiving notification from the packet
@@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := PacketInfo{
- Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }),
- Proto: 0,
- GSO: nil,
- }
-
- e.q.Write(p)
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 3eef7cd56..beefcd008 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
@@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 975309fc8..cb94cbea6 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // enable PACKET_FANOUT mode is the underlying socket is
- // of type AF_PACKET.
- const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ // Enable PACKET_FANOUT mode if the underlying socket is of type
+ // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
+ // prevent gvisor from receiving fragmented packets and the host does the
+ // reassembly on our behalf before delivering the fragments. This makes it
+ // hard to test fragmentation reassembly code in Netstack.
+ const fanoutType = unix.PACKET_FANOUT_HASH
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
@@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
}
var builder iovec.Builder
@@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
-}
-
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 709f829c8..ce4da7230 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- }
+ var r stack.Route
+ r.ResolveWith(raddr)
// Build payload.
payload := buffer.NewView(plen)
@@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -325,9 +324,9 @@ func TestPreserveSrcAddress(t *testing.T) {
// Set LocalLinkAddress in route to the value of the bridged address.
r := &stack.Route{
- RemoteLinkAddress: raddr,
- LocalLinkAddress: baddr,
+ LocalLinkAddress: baddr,
}
+ r.ResolveWith(raddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 38aa694e4..edca57e4e 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- // There should be an ethernet header at the beginning of vv.
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- // Reject the packet if it's shorter than an ethernet header.
- return tcpip.ErrBadAddress
- }
- linkHeader := header.Ethernet(hdr)
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
-
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index e7493e5c5..cbda59775 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 56a611825..22e79ce3a 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -17,7 +17,6 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco
return tcpip.ErrNoRoute
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // WriteRawPacket doesn't get a route or network address, so there's
- // nowhere to write this.
- return tcpip.ErrNoRoute
-}
-
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index 2cdb23475..00b42b924 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index d40de54df..0ee54c3d5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -19,7 +19,6 @@ package nested
import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return e.child.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return e.child.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.
func (e *Endpoint) Wait() {
e.child.Wait()
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index 3922c2a04..9a1b0c0c2 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
// WritePacket implements stack.LinkEndpoint.WritePacket.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 523b0d24b..25c364391 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
// endpoint is the local address on the other end.
- e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
@@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList,
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- panic("not implemented")
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 1d0079bd6..5bea598eb 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index fc1e34fc7..27667f5f0 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -156,7 +155,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
newRoute := r.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = newRoute
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -183,7 +182,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// the route here to ensure it doesn't get released while the
// packet is still in our queue.
newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
@@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
return enqueued, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // TODO(gvisor.dev/issue/3267): Queue these packets as well once
- // WriteRawPacket takes PacketBuffer instead of VectorisedView.
- return e.lower.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() {
e.lower.Wait()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 7fb8a6c49..5660418fa 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
views := pkt.Views()
// Transmit the packet.
@@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- views := vv.Views()
- // Transmit the packet.
- e.mu.Lock()
- ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
- if !ok {
- return tcpip.ErrWouldBlock
- }
-
- return nil
-}
-
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 22d5c97f1..7131392cc 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) {
defer c.cleanup()
// Prepare route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
for iters := 1000; iters > 0; iters-- {
func() {
@@ -342,9 +341,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
// Set both remote and local link address in route.
r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- LocalLinkAddress: newLocalLinkAddress,
+ LocalLinkAddress: newLocalLinkAddress,
}
+ r.ResolveWith(remoteLinkAddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -395,9 +394,8 @@ func TestFillTxQueue(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -444,9 +442,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
c.txq.rx.Flush()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -509,9 +506,8 @@ func TestFillTxMemory(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -557,9 +553,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b3e8c4b92..8d9a91020 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -53,16 +53,35 @@ type endpoint struct {
nested.Endpoint
writer io.Writer
maxPCAPLen uint32
+ logPrefix string
}
var _ stack.GSOEndpoint = (*endpoint)(nil)
var _ stack.LinkEndpoint = (*endpoint)(nil)
var _ stack.NetworkDispatcher = (*endpoint)(nil)
+type direction int
+
+const (
+ directionSend = iota
+ directionRecv
+)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- sniffer := &endpoint{}
+ return NewWithPrefix(lower, "")
+}
+
+// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets prefixed with logPrefix as they traverse
+// the endpoint.
+//
+// logPrefix is prepended to the log line without any separators.
+// E.g. logPrefix = "NIC:en0/" will produce log lines like
+// "NIC:en0/send udp [...]".
+func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint {
+ sniffer := &endpoint{logPrefix: logPrefix}
sniffer.Endpoint.Init(lower, sniffer)
return sniffer
}
@@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, pkt)
+ e.dumpPacket(directionRecv, nil, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// forwards the request to the lower endpoint.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return e.Endpoint.WriteRawPacket(vv)
-}
-
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var fragmentOffset uint16
var moreFragments bool
+ var directionPrefix string
+ switch dir {
+ case directionSend:
+ directionPrefix = "send"
+ case directionRecv:
+ directionPrefix = "recv"
+ default:
+ panic(fmt.Sprintf("unrecognized direction: %d", dir))
+ }
+
// Clone the packet buffer to not modify the original.
//
// We don't clone the original packet buffer so that the new packet buffer
@@ -248,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
- "%s arp %s (%s) -> %s (%s) valid:%t",
+ "%s%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
+ directionPrefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
- log.Infof("%s unknown network protocol", prefix)
+ log.Infof("%s%s unknown network protocol", prefix, directionPrefix)
return
}
@@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto)
return
}
@@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 9a76bdba7..a364c5801 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" {
return nil, false
}
@@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader().View().IsEmpty() {
- d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt)
}
vv.AppendView(info.Pkt.LinkHeader().View())
}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index ee84c3d96..9b4602c1b 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/gate",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
@@ -25,7 +24,6 @@ go_test(
library = ":waitable",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index b152a0f26..cf0077f43 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -24,7 +24,6 @@ package waitable
import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, err
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if !e.writeGate.Enter() {
- return nil
- }
-
- err := e.lower.WriteRawPacket(vv)
- e.writeGate.Leave()
- return err
-}
-
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 94827fc56..cf7fb5126 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -18,7 +18,6 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.
return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index b38aff0b8..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,12 +7,14 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index f462524c9..0fb373612 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -319,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
copy(h.HardwareAddressSender(), test.senderLinkAddr)
copy(h.ProtocolAddressSender(), test.senderAddr)
copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
if !test.isValid {
// No packets should be sent after receiving an invalid ARP request.
@@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d8e4a3b54..429af69ee 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -18,7 +18,6 @@ go_template_instance(
go_library(
name = "fragmentation",
srcs = [
- "frag_heap.go",
"fragmentation.go",
"reassembler.go",
"reassembler_list.go",
@@ -38,7 +37,6 @@ go_test(
name = "fragmentation_test",
size = "small",
srcs = [
- "frag_heap_test.go",
"fragmentation_test.go",
"reassembler_test.go",
],
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
deleted file mode 100644
index 0b570d25a..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-type fragment struct {
- offset uint16
- vv buffer.VectorisedView
-}
-
-type fragHeap []fragment
-
-func (h *fragHeap) Len() int {
- return len(*h)
-}
-
-func (h *fragHeap) Less(i, j int) bool {
- return (*h)[i].offset < (*h)[j].offset
-}
-
-func (h *fragHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *fragHeap) Push(x interface{}) {
- *h = append(*h, x.(fragment))
-}
-
-func (h *fragHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
-}
-
-// reassamble empties the heap and returns a VectorisedView
-// containing a reassambled version of the fragments inside the heap.
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
- curr := heap.Pop(h).(fragment)
- views := curr.vv.Views()
- size := curr.vv.Size()
-
- if curr.offset != 0 {
- return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
- }
-
- for h.Len() > 0 {
- curr := heap.Pop(h).(fragment)
- if int(curr.offset) < size {
- curr.vv.TrimFront(size - int(curr.offset))
- } else if int(curr.offset) > size {
- return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
- }
- size += curr.vv.Size()
- views = append(views, curr.vv.Views()...)
- }
- return buffer.NewVectorisedView(size, views), nil
-}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index c75ca7d71..1af87d713 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -46,9 +46,17 @@ const (
)
var (
- // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
// provided.
ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 19f4920b3..9b20bb1d8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -15,9 +15,8 @@
package fragmentation
import (
- "container/heap"
- "fmt"
"math"
+ "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,9 +25,11 @@ import (
)
type hole struct {
- first uint16
- last uint16
- deleted bool
+ first uint16
+ last uint16
+ filled bool
+ final bool
+ data buffer.View
}
type reassembler struct {
@@ -38,8 +39,7 @@ type reassembler struct {
proto uint8
mu sync.Mutex
holes []hole
- deleted int
- heap fragHeap
+ filled int
done bool
creationTime int64
pkt *stack.PacketBuffer
@@ -48,49 +48,94 @@ type reassembler struct {
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
- holes: make([]hole, 0, 16),
- heap: make(fragHeap, 0, 8),
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
- first: 0,
- last: math.MaxUint16,
- deleted: false})
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ final: true,
+ })
return r
}
-// updateHoles updates the list of holes for an incoming fragment and
-// returns true iff the fragment filled at least part of an existing hole.
-func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
- used := false
- for i := range r.holes {
- if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
- continue
- }
- used = true
- r.deleted++
- r.holes[i].deleted = true
- if first > r.holes[i].first {
- r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
- }
- if last < r.holes[i].last && more {
- r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
- }
- }
- return used
-}
-
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
- consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, 0, nil
}
- if r.updateHoles(first, last, more) {
+
+ var holeFound bool
+ var consumed int
+ for i := range r.holes {
+ currentHole := &r.holes[i]
+
+ if last < currentHole.first || currentHole.last < first {
+ continue
+ }
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
+ }
+
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ final: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ final: currentHole.final,
+ })
+ currentHole.final = false
+ }
+ v := pkt.Data.ToOwnedView()
+ consumed = v.Size()
+ r.size += consumed
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
+ final: currentHole.final,
+ data: v,
+ }
+ r.filled++
// For IPv6, it is possible to have different Protocol values between
// fragments of a packet (because, unlike IPv4, the Protocol is not used to
// identify a fragment). In this case, only the Protocol of the first
@@ -103,21 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
r.pkt = pkt
r.proto = proto
}
- vv := pkt.Data
- // We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
- consumed = vv.Size()
- r.size += consumed
+
+ break
+ }
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
}
- // Check if all the holes have been deleted and we are ready to reassamble.
- if r.deleted < len(r.holes) {
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- res, err := r.heap.reassemble()
- if err != nil {
- return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ var size int
+ views := make([]buffer.View, 0, len(r.holes))
+ for _, hole := range r.holes {
+ views = append(views, hole.data)
+ size += hole.data.Size()
}
- return res, r.proto, true, consumed, nil
+ return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..2ff03eeeb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -16,92 +16,175 @@ package fragmentation
import (
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
+type processParams struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ wantDone bool
+ wantError error
}
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(size int) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v(size).ToVectorisedView(),
+ })
+ }
+
+ var tests = []struct {
+ name string
+ params []processParams
+ want []hole
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
},
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
+ {
+ name: "One fragment at beginning",
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
+ {
+ name: "One fragment in the middle",
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
+ {
+ name: "One fragment at the end",
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: true, data: v(2)},
+ {first: 0, last: 0, filled: false},
+ },
},
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "One fragment completing a packet",
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 3, filled: true, final: false, data: v(4)},
+ {first: 4, last: 5, filled: true, final: true, data: v(2)},
+ },
},
- },
-}
+ {
+ name: "Two overlapping fragments",
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ }
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ for _, param := range test.params {
+ _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
+ }
+ }
+ if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
new file mode 100644
index 000000000..ca1247c1e
--- /dev/null
+++ b/pkg/tcpip/network/ip/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ip",
+ srcs = ["generic_multicast_protocol.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ ],
+)
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = ["generic_multicast_protocol_test.go"],
+ deps = [
+ ":ip",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/faketime",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
new file mode 100644
index 000000000..f2f0e069c
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -0,0 +1,676 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// hostState is the state a host may be in for a multicast group.
+type hostState int
+
+// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
+// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
+// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
+const (
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
+ // all network interfaces; it requires no storage in the host."
+ //
+ // 'Non-Listener' is the MLDv1 term used to describe this state.
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
+ // delayingMember is the "'Delaying Member' state, when the host belongs to
+ // the group on the interface and has a report delay timer running for that
+ // membership."
+ //
+ // 'Delaying Listener' is the MLDv1 term used to describe this state.
+ delayingMember
+
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
+ // idleMember is the "Idle Member" state, when the host belongs to the group
+ // on the interface and does not have a report delay timer running for that
+ // membership.
+ //
+ // 'Idle Listener' is the MLDv1 term used to describe this state.
+ idleMember
+)
+
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
+// multicastGroupState holds the Generic Multicast Protocol state for a
+// multicast group.
+type multicastGroupState struct {
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
+ state hostState
+
+ // lastToSendReport is true if we sent the last report for the group. It is
+ // used to track whether there are other hosts on the subnet that are also
+ // members of the group.
+ //
+ // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
+ // 8 for MLDv1.
+ lastToSendReport bool
+
+ // delayedReportJob is used to delay sending responses to membership report
+ // messages in order to reduce duplicate reports from multiple hosts on the
+ // interface.
+ //
+ // Must not be nil.
+ delayedReportJob *tcpip.Job
+}
+
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
+// MulticastGroupProtocol is a multicast group protocol whose core state machine
+// can be represented by GenericMulticastProtocolState.
+type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
+ // SendReport sends a multicast report for the specified group address.
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+
+ // SendLeave sends a multicast leave for the specified group address.
+ SendLeave(groupAddress tcpip.Address) *tcpip.Error
+}
+
+// GenericMulticastProtocolState is the per interface generic multicast protocol
+// state.
+//
+// There is actually no protocol named "Generic Multicast Protocol". Instead,
+// the term used to refer to a generic multicast protocol that applies to both
+// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
+// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
+//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
+// GenericMulticastProtocolState.Init MUST be called before calling any of
+// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
+type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
+ opts GenericMulticastProtocolOptions
+
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
+}
+
+// Init initializes the Generic Multicast Protocol state.
+//
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
+}
+
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
+ // The group has already been joined.
+ info.joins++
+ g.memberships[groupAddress] = info
+ return
+ }
+
+ info := multicastGroupState{
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
+ }
+
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }),
+ }
+
+ if g.opts.Protocol.Enabled() {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.memberships[groupAddress] = info
+}
+
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
+ return ok
+}
+
+// LeaveGroupLocked handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ return false
+ }
+
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.memberships[groupAddress] = info
+ return true
+ }
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.memberships, groupAddress)
+ return true
+}
+
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
+//
+// If the group address is unspecified, then reports will be scheduled for all
+// joined groups.
+//
+// Report(s) will be scheduled to be sent after a random duration between 0 and
+// the maximum response time.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 2.4 (for IGMPv2),
+ //
+ // In a Membership Query message, the group address field is set to zero
+ // when sending a General Query, and set to the group address being
+ // queried when sending a Group-Specific Query.
+ //
+ // As per RFC 2710 section 3.6 (for MLDv1),
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ if groupAddress.Unspecified() {
+ // This is a general query as the group address is unspecified.
+ for groupAddress, info := range g.memberships {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+ } else if info, ok := g.memberships[groupAddress]; ok {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// HandleReportLocked handles a report message.
+//
+// If the report is for a joined group, any active delayed report will be
+// cancelled and the host state for the group transitions to idle.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
+ //
+ // If the host receives another host's Report (version 1 or 2) while it has
+ // a timer running, it stops its timer for the specified group and does not
+ // send a Report
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // If a node receives another node's Report from an interface for a
+ // multicast address while it has a timer running for that same address
+ // on that interface, it stops its timer and does not send a Report for
+ // that address, thus suppressing duplicate reports on the link.
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
+ info.delayedReportJob.Cancel()
+ info.lastToSendReport = false
+ info.state = idleMember
+ g.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ info.lastToSendReport = false
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ info.state = idleMember
+ return
+ }
+
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
+// setDelayTimerForAddressRLocked sets timer to send a delay report.
+//
+// Precondition: g.protocolMU MUST be read locked.
+func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // If a timer for the group is already unning, it is reset to the random
+ // value only if the requested Max Response Time is less than the remaining
+ // value of the running timer.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // If a timer for any address is already running, it is reset to the new
+ // random value only if the requested Maximum Response Delay is less than
+ // the remaining value of the running timer.
+ if info.state == delayingMember {
+ // TODO: Reset the timer if time remaining is greater than maxResponseTime.
+ return
+ }
+
+ info.state = delayingMember
+ info.delayedReportJob.Cancel()
+ info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
+}
+
+// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
+func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // When a host receives a Group-Specific Query, it sets a delay timer to a
+ // random value selected from the range (0, Max Response Time]...
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node receives a Multicast-Address-Specific Query, if it is
+ // listening to the queried Multicast Address on the interface from
+ // which the Query was received, it sets a delay timer for that address
+ // to a random value selected from the range [0, Maximum Response Delay],
+ // as above.
+ if maxRespTime == 0 {
+ return 0
+ }
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
+}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..f56f7aa90
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,877 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+)
+
+const (
+ addr1 = tcpip.Address("\x01")
+ addr2 = tcpip.Address("\x02")
+ addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
+)
+
+var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
+
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu sync.RWMutex
+
+ // Must only be accessed with mu held.
+ sendReportGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ makeQueuePackets bool
+
+ // Must only be accessed with mu held.
+ disabled bool
+}
+
+func (m *mockMulticastGroupProtocol) init() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.disabled = !v
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ return !m.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.sendReportGroupAddrCount[groupAddress]++
+ return !m.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.sendLeaveGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendReportGroupAddresses {
+ sendReportGroupAddrCount[a] = 1
+ }
+
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddrCount[a] = 1
+ }
+
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".mu", ".t", ".makeQueuePackets", ".disabled":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
+ return diff
+}
+
+func TestJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if test.shouldSendReports {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if test.shouldSendMessages {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ {
+ mgp.mu.Lock()
+ res := g.LeaveGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr)
+ }
+ }
+ if test.shouldSendMessages {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleReport(t *testing.T) {
+ tests := []struct {
+ name string
+ reportAddr tcpip.Address
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ reportAddr: "",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ reportAddr: "\x00",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ reportAddr: addr1,
+ expectReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ reportAddr: addr3,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report for a group we have a timer scheduled for should
+ // cancel our delayed report timer for the group.
+ mgp.mu.Lock()
+ g.HandleReportLocked(test.reportAddr)
+ mgp.mu.Unlock()
+ if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectReportsFor: []tcpip.Address{addr1},
+ },
+ {
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectReportsFor: nil,
+ },
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should make us schedule a new delayed report if it
+ // is a query directed at us or a general query.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(test.queryAddr, test.maxDelay)
+ mgp.mu.Unlock()
+ if len(test.expectReportsFor) != 0 {
+ clock.Advance(test.maxDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestJoinCount(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should still be considered joined after leaving once.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if !isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving once more should actually remove us from the group.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ mgp.mu.Lock()
+ g.MakeAllNonMemberLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ mgp.mu.RLock()
+ res := g.IsLocallyJoinedRLocked(group)
+ mgp.mu.RUnlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolcited reports.
+ mgp.mu.Lock()
+ g.InitializeGroupsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ mgp.setEnabled(false)
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining groups should not send any reports.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ res := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should not send any reports.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving groups should not send any leave messages.
+ {
+ mgp.mu.Lock()
+ addr2LeaveRes := g.LeaveGroupLocked(addr2)
+ addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
+ addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !addr2LeaveRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
+ }
+ if !addr1IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ if addr2IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestQueuedPackets(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr1)
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d49c44846..3005973d7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -193,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -207,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -223,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -348,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -554,7 +550,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -623,11 +619,11 @@ func TestReceive(t *testing.T) {
view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- NextHeader: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
+ PayloadLength: payloadLen,
+ TransportProtocol: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -937,7 +933,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -997,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -1011,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp.SetIdent(0xdead)
icmp.SetSequence(0xbeef)
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- })
-
+ var extHdrs header.IPv6ExtHdrSerializer
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
+ ExtensionHeaders: extHdrs,
+ })
+
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
@@ -1093,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
dataBuf := [dataLen]byte{1, 2, 3, 4}
data := dataBuf[:]
- ipv4Options := header.IPv4Options{0, 1, 0, 1}
+ ipv4Options := header.IPv4OptionsSerializer{
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ }
+
+ expectOptions := header.IPv4Options{
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ }
ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
@@ -1243,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
+ ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1266,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1276,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1288,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
@@ -1307,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1317,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1336,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1379,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ // NB: we're lying about transport protocol here to verify the raw
+ // fragment header bytes.
+ TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1414,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1449,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 6252614ec..32f53f217 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ipv4",
srcs = [
"icmp.go",
+ "igmp.go",
"ipv4.go",
],
visibility = ["//visibility:public"],
@@ -17,6 +18,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -24,7 +26,10 @@ go_library(
go_test(
name = "ipv4_test",
size = "small",
- srcs = ["ipv4_test.go"],
+ srcs = [
+ "igmp_test.go",
+ "ipv4_test.go",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 488945226..8e392f86c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4PacketsReceived
+ received := stats.ICMP.V4.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.Echo.Increment()
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
@@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4PacketsSent
+ sent := p.stack.Stats().ICMP.V4.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
new file mode 100644
index 000000000..da88d65d1
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -0,0 +1,345 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // igmpV1PresentDefault is the initial state for igmpV1Present in the
+ // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is
+ // the initial state."
+ igmpV1PresentDefault = 0
+
+ // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18
+ // See note on igmpState.igmpV1Present for more detail.
+ v1RouterPresentTimeout = 400 * time.Second
+
+ // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router
+ // will send General Queries with the Max Response Time set to 0. This MUST
+ // be interpreted as a value of 100 (10 seconds)."
+ //
+ // Note that the Max Response Time field is a value in units of deciseconds.
+ v1MaxRespTime = 10 * time.Second
+
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited IGMP reports.
+ //
+ // Obtained from RFC 2236 Section 8.10, Page 19.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
+
+// igmpState is the per-interface IGMP state.
+//
+// igmpState.init() MUST be called after creating an IGMP state.
+type igmpState struct {
+ // The IPv4 endpoint this igmpState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+
+ // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
+ // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
+ // Membership Reports in response to its Queries, and will not pay
+ // attention to Version 2 Membership Reports. Therefore, a state variable
+ // MUST be kept for each interface, describing whether the multicast
+ // Querier on that interface is running IGMPv1 or IGMPv2. This variable
+ // MUST be based upon whether or not an IGMPv1 query was heard in the last
+ // [Version 1 Router Present Timeout] seconds".
+ //
+ // Must be accessed with atomic operations. Holds a value of 1 when true, 0
+ // when false.
+ igmpV1Present uint32
+
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ igmpType := header.IGMPv2MembershipReport
+ if igmp.v1Present() {
+ igmpType = header.IGMPv1MembershipReport
+ }
+ return igmp.writePacket(groupAddress, groupAddress, igmpType)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ // As per RFC 2236 Section 6, Page 8: "If the interface state says the
+ // Querier is running IGMPv1, this action SHOULD be skipped. If the flag
+ // saying we were the last host to report is cleared, this action MAY be
+ // skipped."
+ if igmp.v1Present() {
+ return nil
+ }
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
+}
+
+// init sets up an igmpState struct, and is required to be called before using
+// a new igmpState.
+//
+// Must only be called once for the lifetime of igmp.
+func (igmp *igmpState) init(ep *endpoint) {
+ igmp.ep = ep
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
+ igmp.igmpV1Present = igmpV1PresentDefault
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
+ igmp.setV1Present(false)
+ })
+}
+
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
+ stats := igmp.ep.protocol.stack.Stats()
+ received := stats.IGMP.PacketsReceived
+ headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.IGMP(headerView)
+
+ // Temporarily reset the checksum field to 0 in order to calculate the proper
+ // checksum.
+ wantChecksum := h.Checksum()
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(wantChecksum)
+
+ if gotChecksum != wantChecksum {
+ received.ChecksumErrors.Increment()
+ return
+ }
+
+ switch h.Type() {
+ case header.IGMPMembershipQuery:
+ received.MembershipQuery.Increment()
+ if len(headerView) < header.IGMPQueryMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
+ case header.IGMPv1MembershipReport:
+ received.V1MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPv2MembershipReport:
+ received.V2MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPLeaveGroup:
+ received.LeaveGroup.Increment()
+ // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
+ // Report, are ignored in all states"
+
+ default:
+ // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
+ // be silently ignored. New message types may be used by newer versions of
+ // IGMP, by multicast routing protocols, or other uses."
+ received.Unrecognized.Increment()
+ }
+}
+
+func (igmp *igmpState) v1Present() bool {
+ return atomic.LoadUint32(&igmp.igmpV1Present) == 1
+}
+
+func (igmp *igmpState) setV1Present(v bool) {
+ if v {
+ atomic.StoreUint32(&igmp.igmpV1Present, 1)
+ } else {
+ atomic.StoreUint32(&igmp.igmpV1Present, 0)
+ }
+}
+
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
+ // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
+ // then change the state to note that an IGMPv1 router is present and
+ // schedule the query received Job.
+ if maxRespTime == 0 && igmp.Enabled() {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ igmp.setV1Present(true)
+ maxRespTime = v1MaxRespTime
+ }
+
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
+}
+
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
+}
+
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+ igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
+ igmpData.SetType(igmpType)
+ igmpData.SetGroupAddress(groupAddress)
+ igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
+ Data: buffer.View(igmpData).ToVectorisedView(),
+ })
+
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
+ }
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.IGMPProtocolNumber,
+ TTL: header.IGMPTTL,
+ TOS: stack.DefaultTOS,
+ }, header.IPv4OptionsSerializer{
+ &header.IPv4SerializableRouterAlertOption{},
+ })
+
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ switch igmpType {
+ case header.IGMPv1MembershipReport:
+ sentStats.V1MembershipReport.Increment()
+ case header.IGMPv2MembershipReport:
+ sentStats.V2MembershipReport.Increment()
+ case header.IGMPLeaveGroup:
+ sentStats.LeaveGroup.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
+ }
+ return true, nil
+}
+
+// joinGroup handles adding a new group to the membership map, setting up the
+// IGMP state for the group, and sending and scheduling the required
+// messages.
+//
+// If the group already exists in the membership map, returns
+// tcpip.ErrDuplicateAddress.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Leave Group message
+// if required.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the IGMP state for each group that has
+// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) initializeAll() {
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
new file mode 100644
index 000000000..1ee573ac8
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -0,0 +1,215 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
+ multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
+ nicID = 1
+)
+
+// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set. Raises a t.Error if
+// any field does not match.
+func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(igmpType),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(1, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ return e, s, clock
+}
+
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: 1,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(igmpType)
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
+// present on the network. The IGMP stack will then send IGMPv1 Membership
+// reports for backwards compatibility.
+func TestIgmpV1Present(t *testing.T) {
+ e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
+ }
+
+ // This NIC will send an IGMPv2 report immediately, before this test can get
+ // the IGMPv1 General Membership Query in.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Inject an IGMPv1 General Membership Query which is identical to a standard
+ // membership query except the Max Response Time is set to 0, which will tell
+ // the stack that this is a router using IGMPv1. Send it to the all systems
+ // group which is the only group this host belongs to.
+ createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems)
+ if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
+ t.Fatalf("got Membership Queries received = %d, want = 1", got)
+ }
+
+ // Before advancing the clock, verify that this host has not sent a
+ // V1MembershipReport yet.
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
+ }
+
+ // Verify the solicited Membership Report is sent. Now that this NIC has seen
+ // an IGMPv1 query, it should send an IGMPv1 Membership Report.
+ p, ok = e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ p, ok = e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
+}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1efe6297a..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -83,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -93,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -121,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.igmp.initializeAll()
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -157,19 +172,27 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ e.mu.igmp.softLeaveAll()
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -198,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
hdrLen := header.IPv4MinimumSize
- var opts header.IPv4Options
- if params.Options != nil {
- var ok bool
- if opts, ok = params.Options.(header.IPv4Options); !ok {
- panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
- }
- hdrLen += opts.SizeWithPadding()
- if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
- }
+ var optLen int
+ if options != nil {
+ optLen = int(options.Length())
+ }
+ hdrLen += optLen
+ if hdrLen > header.IPv4MaximumHeaderSize {
+ // Since we have no way to report an error we must either panic or create
+ // a packet which is different to what was requested. Choose panic as this
+ // would be a programming error that should be caught in testing.
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := uint16(pkt.Size())
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- Options: opts,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ Options: options,
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -259,7 +279,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -347,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -461,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -566,21 +586,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
- srcAddr := h.SourceAddress()
- dstAddr := h.DestinationAddress()
-
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- _ = e.forwardPacket(pkt)
- return
- }
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
- addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -608,16 +613,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
+ if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ // Make sure the source address is not a subnet-local broadcast address.
+ if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ if subnet.IsBroadcast(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ } else if !e.IsInGroup(dstAddr) {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
- pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ _ = e.forwardPacket(pkt)
+ return
+ }
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
@@ -692,6 +723,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
+ if p == header.IGMPProtocolNumber {
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
+ return
+ }
if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
@@ -747,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -770,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
defer e.mu.Unlock()
loopback := e.nic.IsLoopback()
- addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
- })
- if addressEndpoint != nil {
- return addressEndpoint
- }
-
- if !allowTemp {
- return nil
- }
-
- addr := localAddr.WithPrefix()
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
- if err != nil {
- // AddAddress only returns an error if the address is already assigned,
- // but we just checked above if the address exists so we expect no error.
- panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
- }
- return addressEndpoint
+ }, allowTemp, tempPEB)
}
// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
@@ -816,28 +850,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -863,6 +912,8 @@ type protocol struct {
hashIV uint32
fragmentation *fragmentation.Fragmentation
+
+ options Options
}
// Number returns the ipv4 protocol number.
@@ -987,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+// Options holds options to configure a new protocol.
+type Options struct {
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
+}
+
+// NewProtocolWithOptions returns an IPv4 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -1007,14 +1064,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- p := &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ options: opts,
+ }
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
- p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
- return p
+}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
}
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
@@ -1129,6 +1194,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
}
pointer := tsOpt.Pointer()
+ // RFC 791 page 22 states: "The smallest legal value is 5."
+ // Since the pointer is 1 based, and the header is 4 bytes long the
+ // pointer must point beyond the header therefore 4 or less is bad.
+ if pointer <= header.IPv4OptionTimestampHdrLength {
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
// calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
@@ -1215,7 +1286,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
}
- nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+ pointer := rrOpt.Pointer()
+ // RFC 791 page 20 states:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ // Since the pointer is 1 based, and the header is 3 bytes long the
+ // pointer must point beyond the header therefore 3 or less is bad.
+ if pointer <= header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
// RFC 791 page 21 says
// If the route data area is already full (the pointer exceeds the
@@ -1230,14 +1309,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// do this (as do most implementations). It is probable that the inclusion
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
- if nextSlot >= optlen {
+ if pointer > optlen {
return 0, nil
}
// The data area isn't full but there isn't room for a new entry.
// Either Length or Pointer could be bad. We must select Pointer for Linux
- // compatibility, even if only the length is bad.
- if nextSlot+header.IPv4AddressSize > optlen {
+ // compatibility, even if only the length is bad. NB. pointer is 1 based.
+ if pointer+header.IPv4AddressSize > optlen+1 {
if false {
// This is what we would do if we were not being Linux compatible.
// Check for bad pointer or length value. Must be a multiple of 4 after
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4e4e1f3b4..9e2d2cfd6 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -103,105 +103,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
-// fields when options are supplied.
-func TestIPv4EncodeOptions(t *testing.T) {
- tests := []struct {
- name string
- options header.IPv4Options
- encodedOptions header.IPv4Options // reply should look like this
- wantIHL int
- }{
- {
- name: "valid no options",
- wantIHL: header.IPv4MinimumSize,
- },
- {
- name: "one byte options",
- options: header.IPv4Options{1},
- encodedOptions: header.IPv4Options{1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "two byte options",
- options: header.IPv4Options{1, 1},
- encodedOptions: header.IPv4Options{1, 1, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "three byte options",
- options: header.IPv4Options{1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "four byte options",
- options: header.IPv4Options{1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "five byte options",
- options: header.IPv4Options{1, 1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 8,
- },
- {
- name: "thirty nine byte options",
- options: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39,
- },
- encodedOptions: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39, 0,
- },
- wantIHL: header.IPv4MinimumSize + 40,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength)
- hdr := buffer.NewPrependable(int(totalLen))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- // To check the padding works, poison the last byte of the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0xff
- ip.SetHeaderLength(0)
- }
- ip.Encode(&header.IPv4Fields{
- Options: test.options,
- })
- options := ip.Options()
- wantOptions := test.encodedOptions
- if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
- t.Errorf("got IHL of %d, want %d", got, want)
- }
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(wantOptions) == 0 && len(options) == 0 {
- return
- }
-
- if diff := cmp.Diff(wantOptions, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
func TestForwarding(t *testing.T) {
const (
nicID1 = 1
@@ -453,14 +354,6 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
- name: "Check option padding",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 0},
- },
- {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -583,7 +476,7 @@ func TestIPv4Sanity(t *testing.T) {
68, 7, 5, 0,
// ^ ^ Linux points here which is wrong.
// | Not a multiple of 4
- 1, 2, 3,
+ 1, 2, 3, 0,
},
shouldFail: true,
expectErrorICMP: true,
@@ -662,6 +555,56 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
+ // Timestamp pointer uses one based counting so 0 is invalid.
+ name: "timestamp pointer invalid",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, 0, 0x00,
+ // ^ 0 instead of 5 or more.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Timestamp pointer cannot be less than 5. It must point past the header
+ // which is 4 bytes. (1 based counting)
+ name: "timestamp pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
+ // ^ header is 4 bytes, so 4 should fail.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "valid timestamp pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
+ // ^ header is 4 bytes, so 5 should succeed.
+ 0, 0, 0, 0,
+ },
+ replyOptions: header.IPv4Options{
+ 68, 8, 9, 0x00,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
// Needs 8 bytes for a type 1 timestamp but there are only 4 free.
name: "bad timer element alignment",
maxTotalLength: ipv4.MaxTotalSize,
@@ -792,7 +735,61 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
- // Confirm linux bug for bug compatibility.
+ // Pointer uses one based counting so 0 is invalid.
+ name: "record route pointer zero",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, 0, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. 3 should fail.
+ name: "record route pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. Check 4 passes. (Duplicates "single
+ // record route with room")
+ name: "valid record route pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: header.IPv4Options{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm Linux bug for bug compatibility.
// Linux returns slot 22 but the error is in slot 21.
name: "multiple record route with not enough room",
maxTotalLength: ipv4.MaxTotalSize,
@@ -863,8 +860,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if len(test.options)%4 != 0 {
+ t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
+ }
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
@@ -883,11 +882,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
- // To check the padding works, poison the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0x01
- }
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
@@ -895,10 +889,19 @@ func TestIPv4Sanity(t *testing.T) {
TTL: test.TTL,
SrcAddr: remoteIPv4Addr,
DstAddr: ipv4Addr.Address,
- Options: test.options,
})
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
+ } else {
+ // Set the calculated header length, since we may manually add options.
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ }
+ if len(test.options) != 0 {
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
}
ip.SetChecksum(0)
ipHeaderChecksum := ip.CalculateChecksum()
@@ -1003,7 +1006,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - paddedOptionLength
+ sizeChange := len(test.replyOptions) - len(test.options)
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2320,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
+ {
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -2513,7 +2538,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2530,7 +2555,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
@@ -2644,8 +2669,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2687,8 +2712,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2736,8 +2761,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != arp.ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress)
}
rep := header.ARP(p.Pkt.NetworkHeader().View())
if got := rep.Op(); got != header.ARPRequest {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,19 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index beb8f562e..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6PacketsSent
- received := stats.V6PacketsReceived
+ sent := stats.V6.PacketsSent
+ received := stats.V6.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6PacketsSent
+ stat := p.stack.Stats().ICMP.V6.PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
@@ -796,7 +827,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
+ isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
+ if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -812,8 +844,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
+ //
+ // If the packet was originally destined to a multicast address, then do not
+ // use the packet's destination address as the source for the response ICMP
+ // packet as "multicast addresses must not be used as source addresses in IPv6
+ // packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -827,7 +864,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
defer route.Release()
stats := p.stack.Stats().ICMP
- sent := stats.V6PacketsSent
+ sent := stats.V6.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9bc02d851..02b18e9a5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -280,11 +296,11 @@ func TestICMPCounts(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -301,7 +317,7 @@ func TestICMPCounts(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -422,11 +454,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -443,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -568,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr {
+ t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr)
}
// Pull the full payload since network header. Needed for header.IPv6 to
@@ -821,11 +853,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
@@ -833,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -898,11 +930,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1016,11 +1048,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(icmpSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1028,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1076,11 +1108,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1195,11 +1227,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(size + payloadSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
@@ -1207,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1349,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
@@ -1413,11 +1445,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1431,8 +1463,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1455,11 +1487,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1473,8 +1505,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1524,8 +1556,8 @@ func TestPacketQueing(t *testing.T) {
t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1554,11 +1586,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1796,11 +1828,11 @@ func TestCallsToNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
ep.HandlePacket(pkt)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 7a00f6314..a49b5ac77 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -19,6 +19,7 @@ import (
"encoding/binary"
"fmt"
"hash/fnv"
+ "math"
"sort"
"sync/atomic"
"time"
@@ -34,7 +35,9 @@ import (
)
const (
+ // ReassembleTimeout controls how long a fragment will be held.
// As per RFC 8200 section 4.5:
+ //
// If insufficient fragments are received to complete reassembly of a packet
// within 60 seconds of the reception of the first-arriving fragment of that
// packet, reassembly of that packet must be abandoned.
@@ -83,6 +86,7 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
}
@@ -118,6 +122,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles an address being assigned.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, ...
+ //
+ // If we just completed DAD for a link-local address, then attempt to send any
+ // queued MLD reports. Note, we may have sent reports already for some of the
+ // groups before we had a valid link-local address to use as the source for
+ // the MLD messages, but that was only so that MLD snooping switches are aware
+ // of our membership to groups - routers would not have handled those reports.
+ //
+ // As per RFC 3590 section 4,
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ if header.IsV6LinkLocalAddress(addr) {
+ e.mu.mld.sendQueuedReports()
+ }
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -224,6 +267,12 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.mld.initializeAll()
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -241,8 +290,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
@@ -251,7 +302,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err *tcpip.Error
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
return true
@@ -273,7 +324,7 @@ func (e *endpoint) Enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
// link-local address.
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
@@ -322,7 +373,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -331,9 +382,17 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -341,7 +400,7 @@ func (e *endpoint) disableLocked() {
// Precondition: e.mu must be write locked.
func (e *endpoint) stopDADForPermanentAddressesLocked() {
// Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.GetKind() != stack.PermanentTentative {
return true
}
@@ -373,19 +432,27 @@ func (e *endpoint) MTU() uint32 {
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
+ // TODO(gvisor.dev/issues/5035): The maximum header length returned here does
+ // not open the possibility for the caller to know about size required for
+ // extension headers.
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+ extHdrsLen := extensionHeaders.Length()
+ length := pkt.Size() + extensionHeaders.Length()
+ if length > math.MaxUint16 {
+ panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ }
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(params.Protocol),
- HopLimit: params.TTL,
- TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(length),
+ TransportProtocol: params.Protocol,
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -440,7 +507,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -529,7 +596,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -737,8 +804,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ } else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment()
return
@@ -747,7 +817,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
_ = e.forwardPacket(pkt)
return
}
- addressEndpoint.DecRef()
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1090,9 +1159,16 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
+ // The location of the Next Header field is in a different place in
+ // the initial IPv6 header than it is in the extension headers so
+ // treat it specially.
+ prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset)
+ if previousHeaderStart != 0 {
+ prevHdrIDOffset = previousHeaderStart
+ }
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
+ pointer: prevHdrIDOffset,
}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
@@ -1100,12 +1176,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
- }, pkt)
- stats.UnknownProtocolRcvdPackets.Increment()
- return
+ // Since the iterator returns IPv6RawPayloadHeader for unknown Extension
+ // Header IDs this should never happen unless we missed a supported type
+ // here.
+ panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr))
+
}
}
}
@@ -1153,11 +1228,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
- return nil, err
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1166,6 +1236,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1211,7 +1288,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1234,7 +1312,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
//
// Precondition: e.mu must be read or write locked.
func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
- return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+ return e.mu.addressableEndpointState.GetAddress(localAddr)
}
// MainAddress implements stack.AddressableEndpoint.
@@ -1266,6 +1344,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked returns a link-local address from the primary list
+// of addresses, if one is available.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+ var linkLocalAddr tcpip.Address
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.IsAssigned(false /* allowExpired */) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ linkLocalAddr = addr
+ return false
+ }
+ }
+ return true
+ })
+ return linkLocalAddr
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1285,10 +1383,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1304,6 +1402,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
addressEndpoint: addressEndpoint,
scope: scope,
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
@@ -1376,28 +1476,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1405,7 +1520,8 @@ var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
- stack *stack.Stack
+ stack *stack.Stack
+ options Options
mu struct {
sync.RWMutex
@@ -1429,26 +1545,6 @@ type protocol struct {
forwarding uint32
fragmentation *fragmentation.Fragmentation
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
- // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
- ndpConfigs NDPConfigurations
-
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
- // autoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -1481,16 +1577,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
@@ -1613,17 +1704,17 @@ type Options struct {
// NDPConfigs is the default NDP configurations used by interfaces.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // AutoGenLinkLocal determines whether or not the stack attempts to
+ // auto-generate a link-local address for newly enabled non-loopback
// NICs.
//
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate an IPv6 link-local adddress.
+ // further attempts are made to auto-generate a link-local adddress.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
- AutoGenIPv6LinkLocal bool
+ AutoGenLinkLocal bool
// NDPDisp is the NDP event dispatcher that an integrator can provide to
// receive NDP related events.
@@ -1647,6 +1738,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
@@ -1658,15 +1752,11 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
+ stack: s,
+ options: opts,
+
ids: ids,
hashIV: hashIV,
-
- ndpDisp: opts.NDPDisp,
- ndpConfigs: opts.NDPConfigs,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
}
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
@@ -1712,24 +1802,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
fragPkt.NetworkProtocolNumber = ProtocolNumber
originalIPHeadersLength := len(originalIPHeaders)
- fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+
+ s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ M: more,
+ Identification: id,
+ }}
+
+ fragmentIPHeadersLength := originalIPHeadersLength + s.Length()
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
- fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
}
- fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
- fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
- fragmentHeader.Encode(&header.IPv6FragmentFields{
- M: more,
- FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
- Identification: id,
- NextHeader: uint8(transportProto),
- })
+ nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:])
+
+ fragmentIPHeaders.SetNextHeader(nextHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index a671d4bac..5f07d3af8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -51,6 +51,7 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+ unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
extraHeaderReserve = 50
)
@@ -68,18 +69,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
if got := stats.NeighborAdvert.Value(); got != want {
t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
@@ -126,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -573,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
expectICMP: false,
},
{
+ name: "unknown next header (first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, unknownHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6NextHeaderOffset,
+ },
+ {
+ name: "unknown next header (not first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ unknownHdrID, 0,
+ 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
name: "destination with unknown option skippable action",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -755,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
pointer: header.IPv6FixedHeaderSize,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
- shouldAccept: false,
- },
- {
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -873,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), udpPayload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
+
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
sum = header.Checksum(udpPayload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
@@ -884,16 +913,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: ipv6NextHdr,
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
+ // We're lying about transport protocol here to be able to generate
+ // raw extension headers from the test definitions.
+ TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -982,9 +1009,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- fragmentExtHdrLen = 8
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
+ fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
routingExtHdrLen = 8
@@ -1328,14 +1356,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+65520,
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[:65520],
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
),
},
@@ -1344,14 +1372,17 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+ // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ udpMaximumSizeMinus15 & 0xff,
+ 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[65520:],
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
),
},
@@ -1359,6 +1390,47 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
{
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ (udpMaximumSizeMinus15 & 0xff) + 1,
+ 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1877,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
+ // We're lying about transport protocol here so that we can generate
+ // raw extension headers for the tests.
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
@@ -1925,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -1944,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0 >> 3,
M: true,
Identification: ident,
@@ -1971,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
M: false,
Identification: ident,
@@ -2019,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2084,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2098,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2120,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2136,14 +2205,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2158,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2180,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2196,14 +2262,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2218,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2234,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2280,10 +2343,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2439,7 +2503,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2456,7 +2520,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
@@ -2924,11 +2988,11 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..e8d1e7a79
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
+ mld.ep = ep
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
+}
+
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the MLD state for each group that has
+// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
+}
+
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4816:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ extensionHeaders := header.IPv6ExtHdrSerializer{
+ header.IPv6SerializableHopByHopExtHdr{
+ &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
+ },
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ }, extensionHeaders)
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ mldStat.Increment()
+ return localAddress != header.IPv6Any, nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..e2778b656
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,297 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+)
+
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6WithExtHdr(t, p,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ }
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // addres sis assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval.
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for _ = range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 40da011f8..d515eb622 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -471,17 +475,8 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- rtrSolicit struct {
- // The timer used to send the next router solicitation message.
- timer tcpip.Timer
-
- // Used to let the Router Solicitation timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this ndpState is associated with. MUST be set when the
- // timer is set.
- done *bool
- }
+ // The job used to send the next router solicitation message.
+ rtrSolicitJob *tcpip.Job
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -507,7 +502,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer tcpip.Timer
+ job *tcpip.Job
// Used to let the DAD timer know that it has been stopped.
//
@@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
- var done bool
- var timer tcpip.Timer
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
- ndp.ep.mu.Lock()
- defer ndp.ep.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the IPv6 endpoint lock and stopped
- // DAD before this function obtained the NIC lock. Simply return here and
- // do nothing further.
- return
- }
+ state := dadState{
+ job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.dad[addr]
+ if !ok {
+ panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- dadDone := remaining == 0
-
- var err *tcpip.Error
- if !dadDone {
- // Use the unspecified address as the source address when performing DAD.
- addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
-
- // Do not hold the lock when sending packets which may be a long running
- // task or may block link address resolution. We know this is safe
- // because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the IPv6 endpoint. Note,
- // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
- // the address was removed.
- ndp.ep.mu.Unlock()
- err = ndp.sendDADPacket(addr, addressEndpoint)
- ndp.ep.mu.Lock()
- addressEndpoint.DecRef()
- }
+ dadDone := remaining == 0
- if done {
- // If we reach this point, it means that DAD was stopped after we released
- // the IPv6 endpoint's read lock and before we obtained the write lock.
- return
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ }
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- timer.Reset(ndp.configs.RetransmitTimer)
- return
- }
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ state.job.Schedule(ndp.configs.RetransmitTimer)
+ return
+ }
- // At this point we know that either DAD is done or we hit an error sending
- // the last NDP NS. Either way, clean up addr's DAD state and let the
- // integrator know DAD has completed.
- delete(ndp.dad, addr)
+ // At this point we know that either DAD is done or we hit an error
+ // sending the last NDP NS. Either way, clean up addr's DAD state and let
+ // the integrator know DAD has completed.
+ delete(ndp.dad, addr)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
- }
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
- })
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
- ndp.dad[addr] = dadState{
- timer: timer,
- done: &done,
+ ndp.ep.onAddressAssignedLocked(addr)
+ }
+ }),
}
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ state.job.Schedule(0)
+ ndp.dad[addr] = state
+
return nil
}
@@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
-
- // Route should resolve immediately since snmc is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is not
- // locked, the NIC could have been disabled or removed by another goroutine.
- if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
- return err
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
- }
-
- icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(addr)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
sent.NeighborSolicit.Increment()
-
return nil
}
@@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
return
}
- if dad.timer != nil {
- dad.timer.Stop()
- dad.timer = nil
-
- *dad.done = true
- dad.done = nil
- }
-
+ dad.job.Cancel()
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return nil
}
@@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
@@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
}
addressEndpoint.SetDeprecated(true)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicit.timer != nil {
+ if ndp.rtrSolicitJob != nil {
// We are already soliciting routers.
return
}
@@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- var done bool
- ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
- ndp.ep.mu.Lock()
- if done {
- // If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the IPv6 endpoint lock and stopped
- // solicitations. Simply return here and do nothing further.
- ndp.ep.mu.Unlock()
- return
- }
-
+ ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
- if addressEndpoint == nil {
- // Incase this ends up creating a new temporary address, we need to hold
- // onto the endpoint until a route is obtained. If we decrement the
- // reference count before obtaing a route, the address's resources would
- // be released and attempting to obtain a route after would fail. Once a
- // route is obtainted, it is safe to decrement the reference count since
- // obtaining a route increments the address's reference count.
- addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
- }
- ndp.ep.mu.Unlock()
-
- localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
- addressEndpoint.DecRef()
- if err != nil {
- return
- }
- defer r.Release()
-
- // Route should resolve immediately since
- // header.IPv6AllRoutersMulticastAddress is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is
- // not locked, the IPv6 endpoint could have been disabled or removed by
- // another goroutine.
- if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
- return
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() {
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
var optsSerializer header.NDPOptionsSerializer
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
@@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.ep.mu.Lock()
- if done || remaining == 0 {
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
- } else if ndp.rtrSolicit.timer != nil {
- // Note, we need to explicitly check to make sure that
- // the timer field is not nil because if it was nil but
- // we still reached this point, then we know the IPv6 endpoint
- // was requested to stop soliciting routers so we don't
- // need to send the next Router Solicitation message.
- ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ if remaining != 0 {
+ ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
}
- ndp.ep.mu.Unlock()
})
+ ndp.rtrSolicitJob.Schedule(delay)
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicit.timer == nil {
+ if ndp.rtrSolicitJob == nil {
// Nothing to do.
return
}
- *ndp.rtrSolicit.done = true
- ndp.rtrSolicit.timer.Stop()
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
+ ndp.rtrSolicitJob.Cancel()
+ ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 37e8b1083..05a0d95b2 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -213,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -319,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -599,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != respNSDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
}
- if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
})
e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != test.naDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
}
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -785,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -898,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) {
}
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
+ var extHdrs header.IPv6ExtHdrSerializer
if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
}
+ extHdrsLen := extHdrs.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
Data: payload.ToVectorisedView(),
})
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
})
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
ep.HandlePacket(pkt)
}
@@ -1122,7 +1118,7 @@ func TestNDPValidation(t *testing.T) {
s.SetForwarding(ProtocolNumber, true)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
rxRA := stats.RouterAdvert
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..05d98a0a5
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,1261 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
+ // NIC will wait before sending an unsolicited report after joining a
+ // multicast group, in deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr)
+)
+
+// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
+// sent to the provided address with the passed fields set.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv6WithExtHdr(t, payload,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(ipv6Addr),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(ipv4Addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(header.IGMPType(igmpType)),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+ s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
+ return e, s, clock
+}
+
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ igmpEnabled := v4 && mgpEnabled
+ mldEnabled := !v4 && mgpEnabled
+
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: mldEnabled,
+ },
+ }),
+ },
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err)
+ }
+
+ return s, clock
+}
+
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// To not interfere with tests, checkInitialIPv6Groups will leave the added
+// address's solicited node multicast group so that the tests can all assume
+// the NIC has not joined any IPv6 groups.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+ t.Helper()
+
+ stats := s.Stats().ICMP.V6.PacketsSent
+
+ reportCounter++
+ if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+ t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+ }
+
+ // Leave the group to not affect the tests. This is fine since we are not
+ // testing DAD or the solicited node address specifically.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+ }
+ leaveCounter++
+ if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+ t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ return reportCounter, leaveCounter
+}
+
+// createAndInjectIGMPPacket creates and injects an IGMP packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: header.IGMPTTL,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(header.IGMPType(igmpType))
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// createAndInjectMLDPacket creates and injects an MLD packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+ icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+ icmp.SetType(header.ICMPv6Type(mldType))
+ mld := header.MLD(icmp.MessageBody())
+ mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+ mld.SetMulticastAddress(groupAddress)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+
+ e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestMGPDisabled tests that the multicast group protocol is not enabled by
+// default.
+func TestMGPDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
+
+ // This NIC may join multicast groups when it is enabled but since MGP is
+ // disabled, no reports should be sent.
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
+ }
+
+ // Inject a general query message. This should only trigger a report to be
+ // sent if the MGP was enabled.
+ test.rxQuery(e)
+ if got := test.receivedQueryStat(s).Value(); got != 1 {
+ t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
+ }
+ })
+ }
+}
+
+func TestMGPReceiveCounters(t *testing.T) {
+ tests := []struct {
+ name string
+ headerType uint8
+ maxRespTime byte
+ groupAddress tcpip.Address
+ statCounter func(*stack.Stack) *tcpip.StatCounter
+ rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
+ }{
+ {
+ name: "IGMP Membership Query",
+ headerType: igmpMembershipQuery,
+ maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
+ groupAddress: header.IPv4Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv1 Membership Report",
+ headerType: igmpv1MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv2 Membership Report",
+ headerType: igmpv2MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMP Leave Group",
+ headerType: igmpLeaveGroup,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllRoutersGroup,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.LeaveGroup
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "MLD Query",
+ headerType: mldQuery,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Report",
+ headerType: mldReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Done",
+ headerType: mldDone,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
+
+ test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
+ if got := test.statCounter(s).Value(); got != 1 {
+ t.Fatalf("got %s received = %d, want = 1", test.name, got)
+ }
+ })
+ }
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports.
+func TestMGPJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ // Test joining a specific address explicitly and verify a Report is sent
+ // immediately.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Verify the second report is sent by the maximum unsolicited response
+ // interval.
+ p, ok := e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message.
+func TestMGPLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ if got := test.sentReportStat(s).Value(); got != reportCounter {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving the group should trigger an leave/done message to be sent.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ leaveCounter++
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a leave message to be sent")
+ } else {
+ test.validateLeave(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+func TestMGPQueryMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ multicastAddr tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Unspecified",
+ multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+ expectReport: true,
+ },
+ {
+ name: "Specified",
+ multicastAddr: test.multicastAddr,
+ expectReport: true,
+ },
+ {
+ name: "Specified other address",
+ multicastAddr: func() tcpip.Address {
+ addrBytes := []byte(test.multicastAddr)
+ addrBytes[len(addrBytes)-1]++
+ return tcpip.Address(addrBytes)
+ }(),
+ expectReport: false,
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ for i := 0; i < maxUnsolicitedReports; i++ {
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected %d-th report message to be sent", i)
+ } else {
+ test.validateReport(t, p)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not send any more packets until a query.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ // Receive a query message which should trigger a report to be sent at
+ // some time before the maximum response time if the report is
+ // targeted at the host.
+ const maxRespTime = 100
+ test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+ }
+
+ if subTest.expectReport {
+ clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+func TestMGPReportMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ rxReport func(*channel.Endpoint)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Receiving a report for a group we joined should cancel any further
+ // reports.
+ test.rxReport(e)
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("sent unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving a group after getting a report should not send a leave/done
+ // message.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ clock.Advance(time.Hour)
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+ ipv6HeaderIter := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var transport header.IPv6RawPayloadHeader
+ for {
+ h, done, err := ipv6HeaderIter.Next()
+ if err != nil {
+ t.Fatalf("ipv6HeaderIter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ }
+ if t, ok := h.(header.IPv6RawPayloadHeader); ok {
+ transport = t
+ break
+ }
+ }
+
+ if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(transport.Buf.ToView())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i, _ := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
+// performed on loopback interfaces since they have no neighbours.
+func TestMGPDisabledOnLoopback(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
+
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 2a6c7c7c0..b60a5fd76 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,31 +15,350 @@
package tcpip
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/sync"
)
-// SocketOptions contains all the variables which store values for socket
-// level options.
+// SocketOptionsHandler holds methods that help define endpoint specific
+// behavior for socket level socket options. These must be implemented by
+// endpoints to get notified when socket level options are set.
+type SocketOptionsHandler interface {
+ // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
+ OnReuseAddressSet(v bool)
+
+ // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint.
+ OnReusePortSet(v bool)
+
+ // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
+ OnKeepAliveSet(v bool)
+
+ // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint.
+ // Note that v will be the inverse of TCP_NODELAY option.
+ OnDelayOptionSet(v bool)
+
+ // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
+ OnCorkOptionSet(v bool)
+
+ // LastError is invoked when SO_ERROR is read for an endpoint.
+ LastError() *Error
+}
+
+// DefaultSocketOptionsHandler is an embeddable type that implements no-op
+// implementations for SocketOptionsHandler methods.
+type DefaultSocketOptionsHandler struct{}
+
+var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil)
+
+// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet.
+func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {}
+
+// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet.
+func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
+
+// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet.
+func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {}
+
+// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet.
+func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
+
+// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet.
+func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
+
+// LastError implements SocketOptionsHandler.LastError.
+func (*DefaultSocketOptionsHandler) LastError() *Error {
+ return nil
+}
+
+// SocketOptions contains all the variables which store values for SOL_SOCKET,
+// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
// +stateify savable
type SocketOptions struct {
- // mu protects fields below.
- mu sync.Mutex `state:"nosave"`
- broadcastEnabled bool
+ handler SocketOptionsHandler
+
+ // These fields are accessed and modified using atomic operations.
+
+ // broadcastEnabled determines whether datagram sockets are allowed to
+ // send packets to a broadcast address.
+ broadcastEnabled uint32
+
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control
+ // messages are enabled.
+ passCredEnabled uint32
+
+ // noChecksumEnabled determines whether UDP checksum is disabled while
+ // transmitting for this socket.
+ noChecksumEnabled uint32
+
+ // reuseAddressEnabled determines whether Bind() should allow reuse of
+ // local address.
+ reuseAddressEnabled uint32
+
+ // reusePortEnabled determines whether to permit multiple sockets to be
+ // bound to an identical socket address.
+ reusePortEnabled uint32
+
+ // keepAliveEnabled determines whether TCP keepalive is enabled for this
+ // socket.
+ keepAliveEnabled uint32
+
+ // multicastLoopEnabled determines whether multicast packets sent over a
+ // non-loopback interface will be looped back. Analogous to inet->mc_loop.
+ multicastLoopEnabled uint32
+
+ // receiveTOSEnabled is used to specify if the TOS ancillary message is
+ // passed with incoming packets.
+ receiveTOSEnabled uint32
+
+ // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
+ // message is passed with incoming packets.
+ receiveTClassEnabled uint32
+
+ // receivePacketInfoEnabled is used to specify if more inforamtion is
+ // provided with incoming packets such as interface index and address.
+ receivePacketInfoEnabled uint32
+
+ // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
+ // being written have an IP header and the endpoint should not attach an IP
+ // header.
+ hdrIncludedEnabled uint32
+
+ // v6OnlyEnabled is used to determine whether an IPv6 socket is to be
+ // restricted to sending and receiving IPv6 packets only.
+ v6OnlyEnabled uint32
+
+ // quickAckEnabled is used to represent the value of TCP_QUICKACK option.
+ // It currently does not have any effect on the TCP endpoint.
+ quickAckEnabled uint32
+
+ // delayOptionEnabled is used to specify if data should be sent out immediately
+ // by the transport protocol. For TCP, it determines if the Nagle algorithm
+ // is on or off.
+ delayOptionEnabled uint32
+
+ // corkOptionEnabled is used to specify if data should be held until segments
+ // are full by the TCP transport protocol.
+ corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
+
+ // mu protects the access to the below fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // linger determines the amount of time the socket should linger before
+ // close. We currently implement this option for TCP socket only.
+ linger LingerOption
+}
+
+// InitHandler initializes the handler. This must be called before using the
+// socket options utility.
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+ so.handler = handler
+}
+
+func storeAtomicBool(addr *uint32, v bool) {
+ var val uint32
+ if v {
+ val = 1
+ }
+ atomic.StoreUint32(addr, val)
}
// GetBroadcast gets value for SO_BROADCAST option.
func (so *SocketOptions) GetBroadcast() bool {
- so.mu.Lock()
- defer so.mu.Unlock()
-
- return so.broadcastEnabled
+ return atomic.LoadUint32(&so.broadcastEnabled) != 0
}
// SetBroadcast sets value for SO_BROADCAST option.
func (so *SocketOptions) SetBroadcast(v bool) {
+ storeAtomicBool(&so.broadcastEnabled, v)
+}
+
+// GetPassCred gets value for SO_PASSCRED option.
+func (so *SocketOptions) GetPassCred() bool {
+ return atomic.LoadUint32(&so.passCredEnabled) != 0
+}
+
+// SetPassCred sets value for SO_PASSCRED option.
+func (so *SocketOptions) SetPassCred(v bool) {
+ storeAtomicBool(&so.passCredEnabled, v)
+}
+
+// GetNoChecksum gets value for SO_NO_CHECK option.
+func (so *SocketOptions) GetNoChecksum() bool {
+ return atomic.LoadUint32(&so.noChecksumEnabled) != 0
+}
+
+// SetNoChecksum sets value for SO_NO_CHECK option.
+func (so *SocketOptions) SetNoChecksum(v bool) {
+ storeAtomicBool(&so.noChecksumEnabled, v)
+}
+
+// GetReuseAddress gets value for SO_REUSEADDR option.
+func (so *SocketOptions) GetReuseAddress() bool {
+ return atomic.LoadUint32(&so.reuseAddressEnabled) != 0
+}
+
+// SetReuseAddress sets value for SO_REUSEADDR option.
+func (so *SocketOptions) SetReuseAddress(v bool) {
+ storeAtomicBool(&so.reuseAddressEnabled, v)
+ so.handler.OnReuseAddressSet(v)
+}
+
+// GetReusePort gets value for SO_REUSEPORT option.
+func (so *SocketOptions) GetReusePort() bool {
+ return atomic.LoadUint32(&so.reusePortEnabled) != 0
+}
+
+// SetReusePort sets value for SO_REUSEPORT option.
+func (so *SocketOptions) SetReusePort(v bool) {
+ storeAtomicBool(&so.reusePortEnabled, v)
+ so.handler.OnReusePortSet(v)
+}
+
+// GetKeepAlive gets value for SO_KEEPALIVE option.
+func (so *SocketOptions) GetKeepAlive() bool {
+ return atomic.LoadUint32(&so.keepAliveEnabled) != 0
+}
+
+// SetKeepAlive sets value for SO_KEEPALIVE option.
+func (so *SocketOptions) SetKeepAlive(v bool) {
+ storeAtomicBool(&so.keepAliveEnabled, v)
+ so.handler.OnKeepAliveSet(v)
+}
+
+// GetMulticastLoop gets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) GetMulticastLoop() bool {
+ return atomic.LoadUint32(&so.multicastLoopEnabled) != 0
+}
+
+// SetMulticastLoop sets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) SetMulticastLoop(v bool) {
+ storeAtomicBool(&so.multicastLoopEnabled, v)
+}
+
+// GetReceiveTOS gets value for IP_RECVTOS option.
+func (so *SocketOptions) GetReceiveTOS() bool {
+ return atomic.LoadUint32(&so.receiveTOSEnabled) != 0
+}
+
+// SetReceiveTOS sets value for IP_RECVTOS option.
+func (so *SocketOptions) SetReceiveTOS(v bool) {
+ storeAtomicBool(&so.receiveTOSEnabled, v)
+}
+
+// GetReceiveTClass gets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) GetReceiveTClass() bool {
+ return atomic.LoadUint32(&so.receiveTClassEnabled) != 0
+}
+
+// SetReceiveTClass sets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) SetReceiveTClass(v bool) {
+ storeAtomicBool(&so.receiveTClassEnabled, v)
+}
+
+// GetReceivePacketInfo gets value for IP_PKTINFO option.
+func (so *SocketOptions) GetReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0
+}
+
+// SetReceivePacketInfo sets value for IP_PKTINFO option.
+func (so *SocketOptions) SetReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receivePacketInfoEnabled, v)
+}
+
+// GetHeaderIncluded gets value for IP_HDRINCL option.
+func (so *SocketOptions) GetHeaderIncluded() bool {
+ return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
+}
+
+// SetHeaderIncluded sets value for IP_HDRINCL option.
+func (so *SocketOptions) SetHeaderIncluded(v bool) {
+ storeAtomicBool(&so.hdrIncludedEnabled, v)
+}
+
+// GetV6Only gets value for IPV6_V6ONLY option.
+func (so *SocketOptions) GetV6Only() bool {
+ return atomic.LoadUint32(&so.v6OnlyEnabled) != 0
+}
+
+// SetV6Only sets value for IPV6_V6ONLY option.
+//
+// Preconditions: the backing TCP or UDP endpoint must be in initial state.
+func (so *SocketOptions) SetV6Only(v bool) {
+ storeAtomicBool(&so.v6OnlyEnabled, v)
+}
+
+// GetQuickAck gets value for TCP_QUICKACK option.
+func (so *SocketOptions) GetQuickAck() bool {
+ return atomic.LoadUint32(&so.quickAckEnabled) != 0
+}
+
+// SetQuickAck sets value for TCP_QUICKACK option.
+func (so *SocketOptions) SetQuickAck(v bool) {
+ storeAtomicBool(&so.quickAckEnabled, v)
+}
+
+// GetDelayOption gets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) GetDelayOption() bool {
+ return atomic.LoadUint32(&so.delayOptionEnabled) != 0
+}
+
+// SetDelayOption sets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) SetDelayOption(v bool) {
+ storeAtomicBool(&so.delayOptionEnabled, v)
+ so.handler.OnDelayOptionSet(v)
+}
+
+// GetCorkOption gets value for TCP_CORK option.
+func (so *SocketOptions) GetCorkOption() bool {
+ return atomic.LoadUint32(&so.corkOptionEnabled) != 0
+}
+
+// SetCorkOption sets value for TCP_CORK option.
+func (so *SocketOptions) SetCorkOption(v bool) {
+ storeAtomicBool(&so.corkOptionEnabled, v)
+ so.handler.OnCorkOptionSet(v)
+}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
+
+// GetLastError gets value for SO_ERROR option.
+func (so *SocketOptions) GetLastError() *Error {
+ return so.handler.LastError()
+}
+
+// GetOutOfBandInline gets value for SO_OOBINLINE option.
+func (*SocketOptions) GetOutOfBandInline() bool {
+ return true
+}
+
+// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
+// support disabling this option.
+func (*SocketOptions) SetOutOfBandInline(bool) {}
+
+// GetLinger gets value for SO_LINGER option.
+func (so *SocketOptions) GetLinger() LingerOption {
so.mu.Lock()
- defer so.mu.Unlock()
+ linger := so.linger
+ so.mu.Unlock()
+ return linger
+}
- so.broadcastEnabled = v
+// SetLinger sets value for SO_LINGER option.
+func (so *SocketOptions) SetLinger(linger LingerOption) {
+ so.mu.Lock()
+ so.linger = linger
+ so.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..9cc6074da 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 9478f3fb7..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
-//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 6dc9e7859..5ec9b3411 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -333,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..03d7b4e0d 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) {
// Make sure the right remote link address is used.
snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
}
// Check NDP NS packet.
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 177bf5516..317f6871d 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -24,9 +24,16 @@ import (
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -175,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -226,6 +234,8 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
}
// removeEntryLocked removes the specified entry from the neighbor cache.
+//
+// Prerequisite: n.mu and entry.mu MUST be locked.
func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
if entry.neigh.State != Static {
n.dynamic.lru.Remove(entry)
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index ed33418f3..732a299f7 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -80,17 +80,20 @@ func entryDiffOptsWithSort() []cmp.Option {
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 493e48031..32399b4f5 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -258,7 +258,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
case Failed:
e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
e.nic.neigh.removeEntryLocked(e)
})
e.job.Schedule(config.UnreachableTime)
@@ -347,9 +347,10 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -511,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
+
+// doubleLock combines two locks into one while maintaining lock ordering.
+//
+// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
+// neighbor is allowed.
+type doubleLock struct {
+ first, second sync.Locker
+}
+
+// Lock locks both locks in order: first then second.
+func (l *doubleLock) Lock() {
+ l.first.Lock()
+ l.second.Lock()
+}
+
+// Unlock unlocks both locks in reverse order: second then first.
+func (l *doubleLock) Unlock() {
+ l.second.Unlock()
+ l.first.Unlock()
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c2b763325..c497d3932 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -89,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
+// | Stale | Delay | Packet queued | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
@@ -101,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | Failed | Packet queued | | |
// | Failed | | Unreachability timer expired | Delete entry | |
type testEntryEventType uint8
@@ -228,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -3433,6 +3435,146 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryFailedToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.Advance(waitFor)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
+ if got := failedLookups.Value(); got != 0 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
+ }
+
+ e.mu.Lock()
+ // Verify queuing a packet to the entry immediately fails.
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ state := e.neigh.State
+ e.mu.Unlock()
+ if state != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
+ }
+
+ if got := failedLookups.Value(); got != 1 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
+ }
+}
+
func TestEntryFailedGetsDeleted(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3e6ceff28..5d037a27e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -54,18 +54,20 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -80,6 +82,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -170,7 +205,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -182,6 +217,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -265,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -277,9 +316,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -561,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -578,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -637,15 +671,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
+ }
// Parse headers.
netProto := n.stack.NetworkProtocolInstance(protocol)
@@ -686,16 +728,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -848,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -861,13 +904,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2cb13c6fa..b334e27c4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -259,15 +259,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // SizeWithPadding returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- SizeWithPadding() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -279,10 +270,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -291,14 +278,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -739,10 +722,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 53cb6694f..de5fe6ffe 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -18,19 +18,22 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -52,8 +55,16 @@ type Route struct {
// address's assigned status without the NIC.
localAddressNIC *NIC
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
// outgoingNIC is the interface this route uses to write packets.
outgoingNIC *NIC
@@ -71,22 +82,24 @@ type Route struct {
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
- addrWithPrefix := addressEndpoint.AddressWithPrefix()
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
if len(remoteAddr) == 0 {
- remoteAddr = addrWithPrefix.Address
+ remoteAddr = localAddr
}
r := makeRoute(
netProto,
- addrWithPrefix.Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -99,8 +112,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// broadcast it.
if len(gateway) > 0 {
r.NextHop = gateway
- } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -108,11 +121,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
@@ -133,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -159,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -170,6 +190,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
@@ -231,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
@@ -244,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
//
// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
return nil, nil
@@ -254,7 +287,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -272,7 +305,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
+ r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
@@ -280,7 +313,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
+ r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
@@ -309,7 +342,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -317,11 +356,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -375,37 +421,44 @@ func (r *Route) MTU() uint32 {
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.mu.localAddressEndpoint != nil {
+ r.mu.localAddressEndpoint.DecRef()
+ r.mu.localAddressEndpoint = nil
}
}
// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
- }
+func (r *Route) Clone() *Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ newRoute := &Route{
+ RemoteAddress: r.RemoteAddress,
+ LocalAddress: r.LocalAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ NextHop: r.NextHop,
+ NetProto: r.NetProto,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
- return *r
-}
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
+ newRoute.mu.Lock()
+ defer newRoute.mu.Unlock()
+ newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
+ if newRoute.mu.localAddressEndpoint != nil {
+ if !newRoute.mu.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+ }
}
- return l
+ newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
+
+ return newRoute
}
// Stack returns the instance of the Stack that owns this route.
@@ -418,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -428,27 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// isInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) isInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- localAddressEndpoint: r.localAddressEndpoint,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e0025e0a9..026d330c4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -171,6 +171,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1118,6 +1121,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1208,10 +1221,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1235,12 +1248,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
netProto,
- localAddressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -1249,10 +1262,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1261,26 +1274,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1294,7 +1307,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1305,7 +1318,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1317,7 +1330,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
return makeRoute(
netProto,
- addressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
nic, /* outboundNIC */
nic, /* localAddressNIC*/
@@ -1329,9 +1342,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1354,8 +1367,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if needRoute {
gateway = route.Gateway
}
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1391,13 +1404,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1409,7 +1422,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1417,12 +1430,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1810,49 +1823,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1912,7 +1896,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2159,3 +2142,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
}
return protos
}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 61db3164b..457990945 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1603,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1657,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1667,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1683,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2407,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2502,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2536,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -3351,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3370,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3393,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3415,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3436,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3459,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3484,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3520,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4091,10 +4100,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4152,3 +4163,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..a692af20b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5b9043d85..66eb562ba 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,14 +38,15 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
// ops is used to set and get socket options.
ops tcpip.SocketOptions
@@ -64,8 +65,11 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -105,8 +109,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -114,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -189,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
- a := &f.acceptQueue[0]
+ a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
return a, nil, nil
}
@@ -206,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -232,7 +226,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -240,7 +234,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f9e83dd1c..45fa62720 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -49,8 +49,9 @@ const ipv4AddressSize = 4
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
-// Note: to support save / restore, it is important that all tcpip errors have
-// distinct error messages.
+// All errors must have unique msg strings.
+//
+// +stateify savable
type Error struct {
msg string
@@ -247,6 +248,16 @@ func (a Address) WithPrefix() AddressWithPrefix {
}
}
+// Unspecified returns true if the address is unspecified.
+func (a Address) Unspecified() bool {
+ for _, b := range a {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -481,6 +492,14 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDestinationAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
}
// PacketOwner is used to get UID and GID of the packet.
@@ -535,7 +554,7 @@ type Endpoint interface {
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (int64, ControlMessages, *Error)
+ Peek([][]byte) (int64, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -593,10 +612,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt SettableSocketOption) *Error
- // SetSockOptBool sets a socket option, for simple cases where a value
- // has the bool type.
- SetSockOptBool(opt SockOptBool, v bool) *Error
-
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
SetSockOptInt(opt SockOptInt, v int) *Error
@@ -604,10 +619,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt GettableSocketOption) *Error
- // GetSockOptBool gets a socket option for simple cases where a return
- // value has the bool type.
- GetSockOptBool(SockOptBool) (bool, *Error)
-
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
GetSockOptInt(SockOptInt) (int, *Error)
@@ -694,79 +705,6 @@ type WriteOptions struct {
Atomic bool
}
-// SockOptBool represents socket options which values have the bool type.
-type SockOptBool int
-
-const (
- // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be held until segments are full by the TCP transport
- // protocol.
- CorkOption SockOptBool = iota
-
- // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be sent out immediately by the transport protocol. For
- // TCP, it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether TCP keepalive is enabled for this socket.
- KeepaliveEnabledOption
-
- // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether multicast packets sent over a non-loopback interface
- // will be looped back.
- MulticastLoopOption
-
- // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether UDP checksum is disabled for this socket.
- NoChecksumOption
-
- // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether SCM_CREDENTIALS socket control messages are enabled.
- //
- // Only supported on Unix sockets.
- PasscredOption
-
- // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
- QuickAckOption
-
- // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
- // specify if the IPV6_TCLASS ancillary message is passed with incoming
- // packets.
- ReceiveTClassOption
-
- // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
- // if the TOS ancillary message is passed with incoming packets.
- ReceiveTOSOption
-
- // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
- // specify if more inforamtion is provided with incoming packets such as
- // interface index and address.
- ReceiveIPPacketInfoOption
-
- // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether Bind() should allow reuse of local address.
- ReuseAddressOption
-
- // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
- // multiple sockets to be bound to an identical socket address.
- ReusePortOption
-
- // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether an IPv6 socket is to be restricted to sending and receiving
- // IPv6 packets only.
- V6OnlyOption
-
- // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
- // endpoint that all packets being written have an IP header and the
- // endpoint should not attach an IP header.
- IPHdrIncludedOption
-
- // AcceptConnOption is used by GetSockOptBool to indicate if the
- // socket is a listening socket.
- AcceptConnOption
-)
-
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
@@ -1158,14 +1096,6 @@ type RemoveMembershipOption MembershipOption
func (*RemoveMembershipOption) isSettableSocketOption() {}
-// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP out-of-band data is delivered along with the normal in-band data.
-type OutOfBandInlineOption int
-
-func (*OutOfBandInlineOption) isGettableSocketOption() {}
-
-func (*OutOfBandInlineOption) isSettableSocketOption() {}
-
// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
@@ -1215,10 +1145,6 @@ type LingerOption struct {
Timeout time.Duration
}
-func (*LingerOption) isGettableSocketOption() {}
-
-func (*LingerOption) isSettableSocketOption() {}
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -1389,6 +1315,18 @@ type ICMPv6PacketStats struct {
// RedirectMsg is the total number of ICMPv6 redirect message packets
// counted.
RedirectMsg *StatCounter
+
+ // MulticastListenerQuery is the total number of Multicast Listener Query
+ // messages counted.
+ MulticastListenerQuery *StatCounter
+
+ // MulticastListenerReport is the total number of Multicast Listener Report
+ // messages counted.
+ MulticastListenerReport *StatCounter
+
+ // MulticastListenerDone is the total number of Multicast Listener Done
+ // messages counted.
+ MulticastListenerDone *StatCounter
}
// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
@@ -1430,6 +1368,10 @@ type ICMPv6SentPacketStats struct {
type ICMPv6ReceivedPacketStats struct {
ICMPv6PacketStats
+ // Unrecognized is the total number of ICMPv6 packets received that the
+ // transport layer does not know how to parse.
+ Unrecognized *StatCounter
+
// Invalid is the total number of ICMPv6 packets received that the
// transport layer could not parse.
Invalid *StatCounter
@@ -1439,25 +1381,90 @@ type ICMPv6ReceivedPacketStats struct {
RouterOnlyPacketsDroppedByHost *StatCounter
}
-// ICMPStats collects ICMP-specific stats (both v4 and v6).
-type ICMPStats struct {
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
// ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
// and a single count of packets which failed to write to the link
// layer.
- V4PacketsSent ICMPv4SentPacketStats
+ PacketsSent ICMPv4SentPacketStats
// ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
// packet type and a single count of invalid packets received.
- V4PacketsReceived ICMPv4ReceivedPacketStats
+ PacketsReceived ICMPv4ReceivedPacketStats
+}
+// ICMPv6Stats collects ICMPv6-specific stats.
+type ICMPv6Stats struct {
// ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
// and a single count of packets which failed to write to the link
// layer.
- V6PacketsSent ICMPv6SentPacketStats
+ PacketsSent ICMPv6SentPacketStats
// ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
// packet type and a single count of invalid packets received.
- V6PacketsReceived ICMPv6ReceivedPacketStats
+ PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // V4 contains the ICMPv4-specifics stats.
+ V4 ICMPv4Stats
+
+ // V6 contains the ICMPv4-specifics stats.
+ V6 ICMPv6Stats
+}
+
+// IGMPPacketStats enumerates counts for all IGMP packet types.
+type IGMPPacketStats struct {
+ // MembershipQuery is the total number of Membership Query messages counted.
+ MembershipQuery *StatCounter
+
+ // V1MembershipReport is the total number of Version 1 Membership Report
+ // messages counted.
+ V1MembershipReport *StatCounter
+
+ // V2MembershipReport is the total number of Version 2 Membership Report
+ // messages counted.
+ V2MembershipReport *StatCounter
+
+ // LeaveGroup is the total number of Leave Group messages counted.
+ LeaveGroup *StatCounter
+}
+
+// IGMPSentPacketStats collects outbound IGMP-specific stats.
+type IGMPSentPacketStats struct {
+ IGMPPacketStats
+
+ // Dropped is the total number of IGMP packets dropped.
+ Dropped *StatCounter
+}
+
+// IGMPReceivedPacketStats collects inbound IGMP-specific stats.
+type IGMPReceivedPacketStats struct {
+ IGMPPacketStats
+
+ // Invalid is the total number of IGMP packets received that IGMP could not
+ // parse.
+ Invalid *StatCounter
+
+ // ChecksumErrors is the total number of IGMP packets dropped due to bad
+ // checksums.
+ ChecksumErrors *StatCounter
+
+ // Unrecognized is the total number of unrecognized messages counted, these
+ // are silently ignored for forward-compatibilty.
+ Unrecognized *StatCounter
+}
+
+// IGMPStats colelcts IGMP-specific stats.
+type IGMPStats struct {
+ // IGMPSentPacketStats contains counts of sent packets by IGMP packet type
+ // and a single count of invalid packets received.
+ PacketsSent IGMPSentPacketStats
+
+ // IGMPReceivedPacketStats contains counts of received packets by IGMP packet
+ // type and a single count of invalid packets received.
+ PacketsReceived IGMPReceivedPacketStats
}
// IPStats collects IP-specific stats (both v4 and v6).
@@ -1665,6 +1672,9 @@ type Stats struct {
// ICMP breaks out ICMP-specific stats (both v4 and v6).
ICMP ICMPStats
+ // IGMP breaks out IGMP-specific stats.
+ IGMP IGMPStats
+
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 1c8e2bc34..c461da137 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -226,3 +226,47 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
}
}
}
+
+func TestAddressUnspecified(t *testing.T) {
+ tests := []struct {
+ addr Address
+ unspecified bool
+ }{
+ {
+ addr: "",
+ unspecified: true,
+ },
+ {
+ addr: "\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01\x00",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x01\x01",
+ unspecified: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) {
+ if got := test.addr.Unspecified(); got != test.unspecified {
+ t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 9b0f3b675..800025fb9 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -25,6 +25,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 421da1add..baaa741cd 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,8 +71,8 @@ func TestInitialLoopbackAddresses(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDispatcher{},
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDispatcher{},
+ AutoGenLinkLocal: true,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
@@ -93,9 +94,10 @@ func TestInitialLoopbackAddresses(t *testing.T) {
}
}
-// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address.
-func TestLoopbackAcceptAllInSubnet(t *testing.T) {
+// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and UDP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
const (
nicID = 1
localPort = 80
@@ -107,7 +109,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
Protocol: header.IPv4ProtocolNumber,
AddressWithPrefix: ipv4Addr,
}
- ipv4Bytes := []byte(ipv4Addr.Address)
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
ipv4Bytes[len(ipv4Bytes)-1]++
otherIPv4Address := tcpip.Address(ipv4Bytes)
@@ -129,7 +131,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to wildcard and send to assigned address",
addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: true,
},
{
@@ -148,7 +150,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
name: "IPv4 bind to other subnet-local address and send to assigned address",
addAddress: ipv4ProtocolAddress,
bindAddr: otherIPv4Address,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: false,
},
{
@@ -161,7 +163,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to assigned address and send to other subnet-local address",
addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4Addr.Address,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
dstAddr: otherIPv4Address,
expectRx: false,
},
@@ -236,13 +238,17 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
}
- if gotPayload, _, err := rep.Read(nil); test.expectRx {
+ var addr tcpip.FullAddress
+ if gotPayload, _, err := rep.Read(&addr); test.expectRx {
if err != nil {
- t.Fatalf("reep.Read(nil): %s", err)
+ t.Fatalf("reep.Read(_): %s", err)
}
if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ if addr.Addr != test.addAddress.AddressWithPrefix.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ }
} else {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
@@ -312,3 +318,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
}
}
+
+// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and TCP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 80
+ )
+
+ ipv4ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
+ ipv4Bytes[len(ipv4Bytes)-1]++
+ otherIPv4Address := tcpip.Address(ipv4Bytes)
+
+ ipv6ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ }
+ ipv6Bytes := []byte(ipv6Addr.Address)
+ ipv6Bytes[len(ipv6Bytes)-1]++
+ otherIPv6Address := tcpip.Address(ipv6Bytes)
+
+ tests := []struct {
+ name string
+ addAddress tcpip.ProtocolAddress
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectAccept bool
+ }{
+ {
+ name: "IPv4 bind to wildcard and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard send to other address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: remoteIPv4Addr,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind to other subnet-local address and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to assigned address and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: false,
+ },
+
+ {
+ name: "IPv6 bind and send to assigned address",
+ addAddress: ipv6ProtocolAddress,
+ bindAddr: ipv6Addr.Address,
+ dstAddr: ipv6Addr.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv6 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv6ProtocolAddress,
+ dstAddr: otherIPv6Address,
+ expectAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer listeningEndpoint.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := listeningEndpoint.Bind(bindAddr); err != nil {
+ t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err)
+ }
+
+ if err := listeningEndpoint.Listen(1); err != nil {
+ t.Fatalf("listeningEndpoint.Listen(1): %s", err)
+ }
+
+ connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer connectingEndpoint.Close()
+
+ connectAddr := tcpip.FullAddress{
+ Addr: test.dstAddr,
+ Port: localPort,
+ }
+ if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ }
+
+ if !test.expectAccept {
+ if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+ return
+ }
+
+ // Wait for the listening endpoint to be "readable". That is, wait for a
+ // new connection.
+ <-ch
+ var addr tcpip.FullAddress
+ if _, _, err := listeningEndpoint.Accept(&addr); err != nil {
+ t.Fatalf("listeningEndpoint.Accept(nil): %s", err)
+ }
+ if addr.Addr != test.addAddress.AddressWithPrefix.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 9d30329f5..2e59f6a42 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -510,10 +510,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
- }
-
+ ep.SocketOptions().SetReuseAddress(true)
ep.SocketOptions().SetBroadcast(true)
bindAddr := tcpip.FullAddress{Port: localPort}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 440cb0352..74fe19e98 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -49,6 +49,7 @@ const (
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
@@ -71,11 +72,9 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -85,7 +84,7 @@ type endpoint struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return &endpoint{
+ ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -96,7 +95,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
- }, nil
+ }
+ ep.ops.InitHandler(ep)
+ return ep, nil
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -129,7 +130,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -142,6 +146,7 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
@@ -267,26 +272,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- if to == nil {
- route = &e.route
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in Route.Resolve()
- // call below.
- e.mu.RUnlock()
- defer e.mu.RLock()
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // Recheck state after lock was re-acquired.
- if e.state != stateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
- }
- } else {
+ route := e.route
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -310,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
@@ -343,26 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.SocketDetachFilterOption:
- return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- }
- return nil
-}
-
-// SetSockOptBool sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
@@ -378,17 +351,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -426,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -857,6 +810,7 @@ func (*endpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 3bff3755a..9faab4b9e 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -60,6 +60,8 @@ type packet struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
@@ -83,8 +85,6 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -107,6 +107,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
+ ep.ops.InitHandler(ep)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -203,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
}
// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -303,26 +304,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- ep.mu.Lock()
- ep.linger = *v
- ep.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -378,26 +368,7 @@ func (ep *endpoint) LastError() *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- ep.mu.Lock()
- *o = ep.linger
- ep.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrNotSupported
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.AcceptConnOption:
- return false, nil
- default:
- return false, tcpip.ErrNotSupported
- }
+ return tcpip.ErrNotSupported
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -551,8 +522,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 4ae1f92ab..87c60bdab 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -58,12 +58,13 @@ type rawPacket struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
- hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -82,10 +83,8 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -114,8 +113,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
- hdrIncluded: !associated,
}
+ e.ops.InitHandler(e)
+ e.ops.SetHeaderIncluded(!associated)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -170,9 +170,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -223,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if opts.To != nil {
+ // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -266,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -296,7 +305,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if e.route.IsResolutionRequired() {
- savedRoute := &e.route
+ savedRoute := e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
e.mu.RUnlock()
@@ -304,7 +313,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
+ if !e.connected || savedRoute != e.route {
e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
@@ -314,7 +323,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return n, ch, err
}
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ n, ch, err := e.finishWrite(payloadBytes, e.route)
e.mu.RUnlock()
return n, ch, err
}
@@ -335,7 +344,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
e.mu.RUnlock()
return n, ch, err
@@ -356,7 +365,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
@@ -382,8 +391,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -393,6 +402,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -516,33 +530,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- e.hdrIncluded = v
- e.mu.Unlock()
- return nil
- }
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -589,33 +585,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- v := e.hdrIncluded
- e.mu.Unlock()
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -756,10 +726,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+// LastError implements tcpip.Endpoint.LastError.
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.remote.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 518449602..cf232b508 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -45,7 +45,9 @@ go_library(
"rcv.go",
"rcv_state.go",
"reno.go",
+ "reno_recovery.go",
"sack.go",
+ "sack_recovery.go",
"sack_scoreboard.go",
"segment.go",
"segment_heap.go",
@@ -91,7 +93,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- shard_count = 10,
+ shard_count = more_shards,
deps = [
":tcp",
"//pkg/rand",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e5adc383..3e1041cbe 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6Only
+ n.ops.SetV6Only(l.v6Only)
n.ID = s.id
n.boundNICID = s.nicID
n.route = route
@@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
@@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
- v6Only := e.v6only
+ v6Only := e.ops.GetV6Only()
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index ac6d879a7..c944dccc0 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,6 +16,7 @@ package tcp
import (
"encoding/binary"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -133,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int {
return 0
}
- max := seqnum.Size(0xffff)
+ max := seqnum.Size(math.MaxUint16)
s := 0
for wnd > max && s < header.MaxWndScale {
s++
@@ -300,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -361,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -496,7 +497,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
}
@@ -547,7 +548,7 @@ func (h *handshake) start() *tcpip.Error {
}
h.sendSYNOpts = synOpts
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -575,7 +576,6 @@ func (h *handshake) complete() *tcpip.Error {
return err
}
defer timer.stop()
-
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@@ -597,7 +597,7 @@ func (h *handshake) complete() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -631,9 +631,8 @@ func (h *handshake) complete() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
-
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@@ -820,8 +819,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
data = data.Clone(nil)
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
mss := int(gso.MSS)
@@ -865,8 +864,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// network endpoint and under the provided identity.
func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
@@ -941,7 +940,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
@@ -1002,7 +1001,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@@ -1080,7 +1079,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
- if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
@@ -1141,7 +1140,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.HardError = tcpip.ErrAborted
+ e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1286,7 +1285,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
userTimeout := e.userTimeout
e.keepalive.Lock()
- if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
@@ -1323,7 +1322,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
e.keepalive.Unlock()
return
@@ -1353,7 +1352,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be hold upon entering this section.
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@@ -1383,7 +1381,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
e.workerCleanup = true
// Lock released below.
@@ -1638,7 +1636,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
}
extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
if newSyn {
- info := e.EndpointInfo.TransportEndpointInfo
+ info := e.TransportEndpointInfo
newID := info.ID
newID.RemoteAddress = ""
newID.RemotePort = 0
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index a6f25896b..1d1b01a6c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) {
}
}
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
@@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
- nep, _, err := c.EP.Accept(&addr)
+ _, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept(&addr)
+ _, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) {
if addr.Addr != context.TestV6Addr {
t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
}
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
- }
}
func TestV4AcceptOnV4(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4f4f4c65e..bb0795f78 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -310,16 +310,12 @@ type Stats struct {
func (*Stats) IsEndpointStats() {}
// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools.
+// can be queried by monitoring tools. This exists to allow tcp-only state to
+// be exposed.
//
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
-
- // HardError is meaningful only when state is stateError. It stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. HardError is protected by endpoint mu.
- HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// +stateify savable
type endpoint struct {
EndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
// a given tcp processor goroutine.
@@ -386,6 +383,11 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
+ // hardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by endpoint mu.
+ hardError *tcpip.Error `state:".(string)"`
+
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -421,7 +423,10 @@ type endpoint struct {
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
- mu sync.Mutex `state:"nosave"`
+ //
+ // During handshake, mu is locked by the protocol listen goroutine and
+ // released by the handshake completion goroutine.
+ mu sync.CrossGoroutineMutex `state:"nosave"`
ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState()
@@ -436,9 +441,8 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
- v6only bool
isConnectNotified bool
// h stores a reference to the current handshake state if the endpoint is in
@@ -506,24 +510,9 @@ type endpoint struct {
// delay is a boolean (0 is false) and must be accessed atomically.
delay uint32
- // cork holds back segments until full.
- //
- // cork is a boolean (0 is false) and must be accessed atomically.
- cork uint32
-
// scoreboard holds TCP SACK Scoreboard information for this endpoint.
scoreboard *SACKScoreboard
- // The options below aren't implemented, but we remember the user
- // settings because applications expect to be able to set/query these
- // options.
-
- // slowAck holds the negated state of quick ack. It is stubbed out and
- // does nothing.
- //
- // slowAck is a boolean (0 is false) and must be accessed atomically.
- slowAck uint32
-
// segmentQueue is used to hand received segments to the protocol
// goroutine. Segments are queued as long as the queue is not full,
// and dropped when it is.
@@ -685,9 +674,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -701,7 +687,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -850,7 +836,6 @@ func (e *endpoint) recentTimestamp() uint32 {
// +stateify savable
type keepalive struct {
sync.Mutex `state:"nosave"`
- enabled bool
idle time.Duration
interval time.Duration
count int
@@ -884,6 +869,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
+ e.ops.SetQuickAck(true)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -907,7 +895,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptBool(tcpip.DelayOption, true)
+ e.ops.SetDelayOption(true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -1049,7 +1037,8 @@ func (e *endpoint) Close() {
return
}
- if e.linger.Enabled && e.linger.Timeout == 0 {
+ linger := e.SocketOptions().GetLinger()
+ if linger.Enabled && linger.Timeout == 0 {
s := e.EndpointState()
isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
if isResetState {
@@ -1169,7 +1158,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -1279,11 +1272,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) LastError() *tcpip.Error {
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) hardErrorLocked() *tcpip.Error {
+ err := e.hardError
+ e.hardError = nil
+ return err
+}
+
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1291,6 +1293,16 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// LastError implements tcpip.Endpoint.LastError.
+func (e *endpoint) LastError() *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return e.lastErrorLocked()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1312,9 +1324,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.HardError
if s == StateError {
- return buffer.View{}, tcpip.ControlMessages{}, he
+ if err := e.hardErrorLocked(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
@@ -1370,9 +1384,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
- return 0, e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return 0, err
+ }
+ return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@@ -1478,7 +1496,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1486,10 +1504,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.HardError
+ return 0, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
@@ -1498,9 +1516,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return 0, tcpip.ErrClosedForReceive
}
- return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return 0, tcpip.ErrWouldBlock
}
// Make a copy of vec so we can modify the slide headers.
@@ -1515,7 +1533,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
for len(v) > 0 {
if len(vec) == 0 {
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
if len(vec[0]) == 0 {
vec = vec[1:]
@@ -1530,7 +1548,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
}
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1602,72 +1620,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
return false, false
}
-// SetSockOptBool sets a socket option.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
-
- case tcpip.CorkOption:
- e.LockUser()
- if !v {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- e.UnlockUser()
-
- case tcpip.DelayOption:
- if v {
- atomic.StoreUint32(&e.delay, 1)
- } else {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- }
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
-
- case tcpip.QuickAckOption:
- o := uint32(1)
- if v {
- o = 0
- }
- atomic.StoreUint32(&e.slowAck, o)
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- e.portFlags.TupleOnly = v
- e.UnlockUser()
-
- case tcpip.ReusePortOption:
- e.LockUser()
- e.portFlags.LoadBalanced = v
- e.UnlockUser()
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+}
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+}
- // We only allow this to be set when we're in the initial state.
- if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
+// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
+func (e *endpoint) OnKeepAliveSet(v bool) {
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+}
- e.LockUser()
- e.v6only = v
- e.UnlockUser()
+// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet.
+func (e *endpoint) OnDelayOptionSet(v bool) {
+ if !v {
+ // Handle delayed data.
+ e.sndWaker.Assert()
}
+}
- return nil
+// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet.
+func (e *endpoint) OnCorkOptionSet(v bool) {
+ if !v {
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ }
}
// SetSockOptInt sets a socket option.
@@ -1851,9 +1836,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
-
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(*v)
@@ -1922,11 +1904,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.LockUser()
- e.linger = *v
- e.UnlockUser()
-
default:
return nil
}
@@ -1949,67 +1926,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
-
- case tcpip.CorkOption:
- return atomic.LoadUint32(&e.cork) != 0, nil
-
- case tcpip.DelayOption:
- return atomic.LoadUint32(&e.delay) != 0, nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- return v, nil
-
- case tcpip.QuickAckOption:
- v := atomic.LoadUint32(&e.slowAck) == 0
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- v := e.portFlags.TupleOnly
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.LockUser()
- v := e.portFlags.LoadBalanced
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.LockUser()
- v := e.v6only
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.MulticastLoopOption:
- return true, nil
-
- case tcpip.AcceptConnOption:
- e.LockUser()
- defer e.UnlockUser()
-
- return e.EndpointState() == StateListen, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -2120,10 +2036,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
- *o = 1
-
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
@@ -2152,11 +2064,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
- case *tcpip.LingerOption:
- e.LockUser()
- *o = e.linger
- e.UnlockUser()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2166,7 +2073,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -2243,7 +2150,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@@ -2417,7 +2327,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
// Call cleanupLocked to free up any reservations.
e.cleanupLocked()
@@ -2697,7 +2607,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// v6only set to false.
if netProto == header.IPv6ProtocolNumber {
stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
- alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4
if alsoBindToV4 {
netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
@@ -2782,7 +2692,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
- // land at the Dispatcher which then can either delivery using the
+ // land at the Dispatcher which then can either deliver using the
// worker go routine or directly do the invoke the tcp processing inline
// based on the state of the endpoint.
}
@@ -3079,6 +2989,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
@@ -3161,7 +3072,7 @@ func (e *endpoint) State() uint32 {
func (e *endpoint) Info() tcpip.EndpointInfo {
e.LockUser()
// Make a copy of the endpoint info.
- ret := e.EndpointInfo
+ ret := e.TransportEndpointInfo
e.UnlockUser()
return &ret
}
@@ -3187,6 +3098,7 @@ func (e *endpoint) Wait() {
}
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index bb901c0f8..ba67176b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -321,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
-func (e *EndpointInfo) saveHardError() string {
- if e.HardError == nil {
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
return ""
}
- return e.HardError.String()
+ return e.hardError.String()
}
// loadHardError is invoked by stateify.
-func (e *EndpointInfo) loadHardError(s string) {
+func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
- e.HardError = tcpip.StringToError(s)
+ e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..672159eed 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 8e0b7c843..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -16,6 +16,7 @@ package tcp
import (
"container/heap"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -48,6 +49,10 @@ type receiver struct {
rcvWndScale uint8
+ // prevBufused is the snapshot of endpoint rcvBufUsed taken when we
+ // advertise a receive window.
+ prevBufUsed int
+
closed bool
// pendingRcvdSegments is bounded by the receive buffer size of the
@@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// outgoing packets, we should use what we have advertised for acceptability
// test.
scaledWindowSize := r.rcvWnd >> r.rcvWndScale
- if scaledWindowSize > 0xffff {
+ if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
- scaledWindowSize = 0xffff
+ scaledWindowSize = math.MaxUint16
}
advertisedWindowSize := scaledWindowSize << r.rcvWndScale
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
@@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) {
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
+ unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ bufUsed := r.ep.receiveBufferUsed()
+
+ // Grow the right edge of the window only for payloads larger than the
+ // the segment overhead OR if the application is actively consuming data.
+ //
+ // Avoiding growing the right edge otherwise, addresses a situation below:
+ // An application has been slow in reading data and we have burst of
+ // incoming segments lengths < segment overhead. Here, our available free
+ // memory would reduce drastically when compared to the advertised receive
+ // window.
+ //
+ // For example: With incoming 512 bytes segments, segment overhead of
+ // 552 bytes (at the time of writing this comment), with receive window
+ // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0
+ // when the curWnd is still 19436 bytes, because for every incoming segment
+ // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1),
+ // while curWnd would reduce by 512 bytes.
+ // Such a situation causes us to keep tail dropping the incoming segments
+ // and never advertise zero receive window to the peer.
+ //
+ // Linux does a similar check for minimal sk_buff size (128):
+ // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783
+ //
+ // Also, if the application is reading the data, we keep growing the right
+ // edge, as we are still advertising a window that we think can be serviced.
+ toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
+
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
@@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP rcvNxt rcvAcc new rcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
// If the new window moves the right edge, then update rcvAcc.
r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
@@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// receiver's estimated RTT.
r.rcvWnd = newWnd
r.rcvWUP = r.rcvNxt
+ r.prevBufUsed = bufUsed
scaledWnd := r.rcvWnd >> r.rcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
+
+ // If we started off with a window larger than what can he held in
+ // the 16bit window field, we ceil the value to the max value.
+ if scaledWnd > math.MaxUint16 {
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
+ }
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go
new file mode 100644
index 000000000..2aa708e97
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno_recovery.go
@@ -0,0 +1,67 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoRecovery stores the variables related to TCP Reno loss recovery
+// algorithm.
+//
+// +stateify savable
+type renoRecovery struct {
+ s *sender
+}
+
+func newRenoRecovery(s *sender) *renoRecovery {
+ return &renoRecovery{s: s}
+}
+
+func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ ack := rcvdSeg.ackNumber
+ snd := rr.s
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window {
+ return
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if !fastRetransmit && ack == snd.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if snd.sndCwnd < snd.fr.maxCwnd {
+ snd.sndCwnd++
+ }
+ return
+ }
+
+ // A partial ack was received. Retransmit this packet and remember it
+ // so that we don't retransmit it again.
+ //
+ // We don't inflate the window because we're putting the same packet
+ // back onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ snd.fr.first = ack
+ snd.dupAckCount = 0
+ snd.resendSegment()
+}
diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go
new file mode 100644
index 000000000..7e813fa96
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_recovery.go
@@ -0,0 +1,120 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+
+// sackRecovery stores the variables related to TCP SACK loss recovery
+// algorithm.
+//
+// +stateify savable
+type sackRecovery struct {
+ s *sender
+}
+
+func newSACKRecovery(s *sender) *sackRecovery {
+ return &sackRecovery{s: s}
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ snd := sr.s
+ snd.SetPipe()
+
+ if smss := int(snd.ep.scoreboard.SMSS()); limit > smss {
+ // Cap segment size limit to s.smss as SACK recovery requires
+ // that all retransmissions or new segments send during recovery
+ // be of <= SMSS.
+ limit = smss
+ }
+
+ nextSegHint := snd.writeList.Front()
+ for snd.outstanding < snd.sndCwnd {
+ var nextSeg *segment
+ var rescueRtx bool
+ nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint)
+ if nextSeg == nil {
+ return dataSent
+ }
+ if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ // New data being sent.
+
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ //
+ // We pass s.smss as the limit as the Step 2) requires that
+ // new data sent should be of size s.smss or less.
+ if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent {
+ return dataSent
+ }
+ dataSent = true
+ snd.outstanding++
+ snd.writeNext = nextSeg.Next()
+ continue
+ }
+
+ // Now handle the retransmission case where we matched either step 1,3 or 4
+ // of the NextSeg algorithm.
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ snd.outstanding++
+ dataSent = true
+ snd.sendSegment(nextSeg)
+
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if rescueRtx {
+ // We do the last part of rule (4) of NextSeg here to update
+ // RescueRxt as until this point we don't know if we are going
+ // to use the rescue transmission.
+ snd.fr.rescueRxt = snd.fr.last
+ } else {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ snd.fr.highRxt = segEnd - 1
+ }
+ }
+ return dataSent
+}
+
+func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ snd := sr.s
+ if fastRetransmit {
+ snd.resendSegment()
+ }
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ end := snd.sndUna.Add(snd.sndWnd)
+ dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end)
+ snd.postXmit(dataSent)
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 2091989cc..5ef73ec74 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -204,7 +204,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return segSize + s.data.Size()
+ return SegSize + s.data.Size()
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
index 0ab7b8f56..392ff0859 100644
--- a/pkg/tcpip/transport/tcp/segment_unsafe.go
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -19,5 +19,6 @@ import (
)
const (
- segSize = int(unsafe.Sizeof(segment{}))
+ // SegSize is the minimal size of the segment overhead.
+ SegSize = int(unsafe.Sizeof(segment{}))
)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 0e0fdf14c..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math"
"sort"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -92,6 +91,17 @@ type congestionControl interface {
PostRecovery()
}
+// lossRecovery is an interface that must be implemented by any supported
+// loss recovery algorithm.
+type lossRecovery interface {
+ // DoRecovery is invoked when loss is detected and segments need
+ // to be retransmitted. The cumulative or selective ACK is passed along
+ // with the flag which identifies whether the connection entered fast
+ // retransmit with this ACK and to retransmit the first unacknowledged
+ // segment.
+ DoRecovery(rcvdSeg *segment, fastRetransmit bool)
+}
+
// sender holds the state necessary to send TCP segments.
//
// +stateify savable
@@ -108,6 +118,9 @@ type sender struct {
// fr holds state related to fast recovery.
fr fastRecovery
+ // lr is the loss recovery algorithm used by the sender.
+ lr lossRecovery
+
// sndCwnd is the congestion window, in packets.
sndCwnd int
@@ -124,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -276,6 +292,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.cc = s.initCongestionControl(ep.cc)
+ s.lr = s.initLossRecovery()
+
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
@@ -330,6 +348,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon
}
}
+// initLossRecovery initiates the loss recovery algorithm for the sender.
+func (s *sender) initLossRecovery() lossRecovery {
+ if s.ep.sackPermitted {
+ return newSACKRecovery(s)
+ }
+ return newRenoRecovery(s)
+}
+
// updateMaxPayloadSize updates the maximum payload size based on the given
// MTU. If this is in response to "packet too big" control packets (indicated
// by the count argument), it also reduces the number of outstanding packets and
@@ -349,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -371,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -378,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -550,7 +584,7 @@ func (s *sender) retransmitTimerExpired() bool {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
- s.leaveFastRecovery()
+ s.leaveRecovery()
}
s.state = RTORecovery
@@ -606,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -789,7 +823,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -808,7 +842,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
+ if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
@@ -913,79 +947,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
return true
}
-// handleSACKRecovery implements the loss recovery phase as described in RFC6675
-// section 5, step C.
-func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
- s.SetPipe()
-
- if smss := int(s.ep.scoreboard.SMSS()); limit > smss {
- // Cap segment size limit to s.smss as SACK recovery requires
- // that all retransmissions or new segments send during recovery
- // be of <= SMSS.
- limit = smss
- }
-
- nextSegHint := s.writeList.Front()
- for s.outstanding < s.sndCwnd {
- var nextSeg *segment
- var rescueRtx bool
- nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint)
- if nextSeg == nil {
- return dataSent
- }
- if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
- // New data being sent.
-
- // Step C.3 described below is handled by
- // maybeSendSegment which increments sndNxt when
- // a segment is transmitted.
- //
- // Step C.3 "If any of the data octets sent in
- // (C.1) are above HighData, HighData must be
- // updated to reflect the transmission of
- // previously unsent data."
- //
- // We pass s.smss as the limit as the Step 2) requires that
- // new data sent should be of size s.smss or less.
- if sent := s.maybeSendSegment(nextSeg, limit, end); !sent {
- return dataSent
- }
- dataSent = true
- s.outstanding++
- s.writeNext = nextSeg.Next()
- continue
- }
-
- // Now handle the retransmission case where we matched either step 1,3 or 4
- // of the NextSeg algorithm.
- // RFC 6675, Step C.4.
- //
- // "The estimate of the amount of data outstanding in the network
- // must be updated by incrementing pipe by the number of octets
- // transmitted in (C.1)."
- s.outstanding++
- dataSent = true
- s.sendSegment(nextSeg)
-
- segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
- if rescueRtx {
- // We do the last part of rule (4) of NextSeg here to update
- // RescueRxt as until this point we don't know if we are going
- // to use the rescue transmission.
- s.fr.rescueRxt = s.fr.last
- } else {
- // RFC 6675, Step C.2
- //
- // "If any of the data octets sent in (C.1) are below
- // HighData, HighRxt MUST be set to the highest sequence
- // number of the retransmitted segment unless NextSeg ()
- // rule (4) was invoked for this retransmission."
- s.fr.highRxt = segEnd - 1
- }
- }
- return dataSent
-}
-
func (s *sender) sendZeroWindowProbe() {
ack, win := s.ep.rcv.getSendParams()
s.unackZeroWindowProbes++
@@ -1014,6 +975,30 @@ func (s *sender) disableZeroWindowProbing() {
s.resendTimer.disable()
}
+func (s *sender) postXmit(dataSent bool) {
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // If the sender has advertized zero receive window and we have
+ // data to be sent out, start zero window probing to query the
+ // the remote for it's receive window size.
+ if s.writeNext != nil && s.sndWnd == 0 {
+ s.enableZeroWindowProbing()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
@@ -1034,55 +1019,29 @@ func (s *sender) sendData() {
}
var dataSent bool
-
- // RFC 6675 recovery algorithm step C 1-5.
- if s.fr.active && s.ep.sackPermitted {
- dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
- } else {
- for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
- if cwndLimit < limit {
- limit = cwndLimit
- }
- if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- // Move writeNext along so that we don't try and scan data that
- // has already been SACKED.
- s.writeNext = seg.Next()
- continue
- }
- if sent := s.maybeSendSegment(seg, limit, end); !sent {
- break
- }
- dataSent = true
- s.outstanding += s.pCount(seg)
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Move writeNext along so that we don't try and scan data that
+ // has already been SACKED.
s.writeNext = seg.Next()
+ continue
}
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
+ s.writeNext = seg.Next()
}
- if dataSent {
- // We sent data, so we should stop the keepalive timer to ensure
- // that no keepalives are sent while there is pending data.
- s.ep.disableKeepaliveTimer()
- }
-
- // If the sender has advertized zero receive window and we have
- // data to be sent out, start zero window probing to query the
- // the remote for it's receive window size.
- if s.writeNext != nil && s.sndWnd == 0 {
- s.enableZeroWindowProbing()
- }
-
- // Enable the timer if we have pending data and it's not enabled yet.
- if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
- s.resendTimer.enable(s.rto)
- }
- // If we have no more pending data, start the keepalive timer.
- if s.sndUna == s.sndNxt {
- s.ep.resetKeepaliveTimer(false)
- }
+ s.postXmit(dataSent)
}
-func (s *sender) enterFastRecovery() {
+func (s *sender) enterRecovery() {
s.fr.active = true
// Save state to reflect we're now in fast recovery.
//
@@ -1090,6 +1049,7 @@ func (s *sender) enterFastRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1104,7 +1064,7 @@ func (s *sender) enterFastRecovery() {
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
-func (s *sender) leaveFastRecovery() {
+func (s *sender) leaveRecovery() {
s.fr.active = false
s.fr.maxCwnd = 0
s.dupAckCount = 0
@@ -1115,57 +1075,6 @@ func (s *sender) leaveFastRecovery() {
s.cc.PostRecovery()
}
-func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
- ack := seg.ackNumber
- // We are in fast recovery mode. Ignore the ack if it's out of
- // range.
- if !ack.InRange(s.sndUna, s.sndNxt+1) {
- return false
- }
-
- // Leave fast recovery if it acknowledges all the data covered by
- // this fast recovery session.
- if s.fr.last.LessThan(ack) {
- s.leaveFastRecovery()
- return false
- }
-
- if s.ep.sackPermitted {
- // When SACK is enabled we let retransmission be governed by
- // the SACK logic.
- return false
- }
-
- // Don't count this as a duplicate if it is carrying data or
- // updating the window.
- if seg.logicalLen() != 0 || s.sndWnd != seg.window {
- return false
- }
-
- // Inflate the congestion window if we're getting duplicate acks
- // for the packet we retransmitted.
- if ack == s.fr.first {
- // We received a dup, inflate the congestion window by 1 packet
- // if we're not at the max yet. Only inflate the window if
- // regular FastRecovery is in use, RFC6675 does not require
- // inflating cwnd on duplicate ACKs.
- if s.sndCwnd < s.fr.maxCwnd {
- s.sndCwnd++
- }
- return false
- }
-
- // A partial ack was received. Retransmit this packet and
- // remember it so that we don't retransmit it again. We don't
- // inflate the window because we're putting the same packet back
- // onto the wire.
- //
- // N.B. The retransmit timer will be reset by the caller.
- s.fr.first = ack
- s.dupAckCount = 0
- return true
-}
-
// isAssignedSequenceNumber relies on the fact that we only set flags once a
// sequencenumber is assigned and that is only done right before we send the
// segment. As a result any segment that has a non-zero flag has a valid
@@ -1228,14 +1137,11 @@ func (s *sender) SetPipe() {
s.outstanding = pipe
}
-// checkDuplicateAck is called when an ack is received. It manages the state
-// related to duplicate acks and determines if a retransmit is needed according
-// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+// detectLoss is called when an ack is received and returns whether a loss is
+// detected. It manages the state related to duplicate acks and determines if
+// a retransmit is needed according to the rules in RFC 6582 (NewReno).
+func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
ack := seg.ackNumber
- if s.fr.active {
- return s.handleFastRecovery(seg)
- }
// We're not in fast recovery yet. A segment is considered a duplicate
// only if it doesn't carry any data and doesn't update the send window,
@@ -1266,14 +1172,14 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
//
// We only do the check here, the incrementing of last to the highest
- // sequence number transmitted till now is done when enterFastRecovery
+ // sequence number transmitted till now is done when enterRecovery
// is invoked.
if !s.fr.last.LessThan(seg.ackNumber) {
s.dupAckCount = 0
return false
}
s.cc.HandleNDupAcks()
- s.enterFastRecovery()
+ s.enterRecovery()
s.dupAckCount = 0
return true
}
@@ -1313,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1415,14 +1322,23 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.SetPipe()
}
- // Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(rcvdSeg)
+ ack := rcvdSeg.ackNumber
+ fastRetransmit := false
+ // Do not leave fast recovery, if the ACK is out of range.
+ if s.fr.active {
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) {
+ s.leaveRecovery()
+ }
+ } else {
+ // Detect loss by counting the duplicates and enter recovery.
+ fastRetransmit = s.detectLoss(rcvdSeg)
+ }
// Stash away the current window size.
s.sndWnd = rcvdSeg.window
- ack := rcvdSeg.ackNumber
-
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
@@ -1477,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1496,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
@@ -1539,19 +1457,24 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.resendTimer.disable()
}
}
+
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if rtx {
- s.resendSegment()
+ if s.fr.active {
+ s.lr.DoRecovery(rcvdSeg, fastRetransmit)
+ // When SACK is enabled data sending is governed by steps in
+ // RFC 6675 Section 5 recovery steps A-C.
+ // See: https://tools.ietf.org/html/rfc6675#section-5.
+ if s.ep.sackPermitted {
+ return
+ }
}
// Send more data now that some of the pending data has been ack'd, or
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
- s.sendData()
- }
+ s.sendData()
}
// sendSegment sends the specified segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ef7f5719f..faf0c0ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) {
expected++
}
}
+
+// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK
+// is received.
+func TestSACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ sendAndReceive(t, c, 8)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9f0fb41e3..351a5e4f5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.LastError(); err != tcpip.ErrAborted {
- t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
- }
// Call Connect again to retreive the handshake failure status
// and stats updates.
@@ -267,7 +264,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
}
// Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
}
@@ -1935,6 +1932,84 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// Test the stack receive window advertisement on receiving segments smaller than
+// segment overhead. It tests for the right edge of the window to not grow when
+// the endpoint is not being read from.
+func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize,
+ Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
+ }
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Bump up the receive buffer size such that, when the receive window grows,
+ // the scaled window exceeds maxUint16.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
+ }
+
+ // Keep the payload size < segment overhead and such that it is a multiple
+ // of the window scaled value. This enables the test to perform equality
+ // checks on the incoming receive window.
+ payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadLen := seqnum.Size(len(payload))
+ iss := seqnum.Value(789)
+ seqNum := iss.Add(1)
+
+ // Send payload to the endpoint and return the advertised receive window
+ // from the endpoint.
+ getIncomingRcvWnd := func() uint32 {
+ c.SendPacket(payload, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ Flags: header.TCPFlagAck,
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(payloadLen)
+
+ pkt := c.GetPacket()
+ return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
+ }
+
+ // Read the advertised receive window with the ACK for payload.
+ rcvWnd := getIncomingRcvWnd()
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Read the data so that the subsequent ACK from the endpoint
+ // grows the right edge of the window.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("got Read(nil) = %s", err)
+ }
+
+ // Check if we have received max uint16 as our advertised
+ // scaled window now after a read above.
+ maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
+ if got, want := getIncomingRcvWnd(), maxRcv; got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+}
+
func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -2532,10 +2607,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, true)
+ ep.SocketOptions().SetCorkOption(true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, false)
+ ep.SocketOptions().SetCorkOption(false)
},
},
}
@@ -2627,7 +2702,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2675,7 +2750,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2708,7 +2783,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptBool(tcpip.DelayOption, false)
+ c.EP.SocketOptions().SetDelayOption(false)
// Check that data is received.
second := c.GetPacket()
@@ -2745,8 +2820,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
}
for _, test := range tests {
@@ -3198,6 +3273,11 @@ loop:
case tcpip.ErrWouldBlock:
select {
case <-ch:
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@@ -3207,14 +3287,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
- }
+
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
-
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
@@ -4150,7 +4226,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Check that peek works.
peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
+ n, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
@@ -4176,7 +4252,7 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -4193,9 +4269,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4205,9 +4279,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4218,9 +4290,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4233,9 +4303,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4246,9 +4314,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4261,9 +4327,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4656,13 +4720,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(false)
default:
t.Fatalf("unknown network: '%s'", network)
}
@@ -4998,9 +5058,7 @@ func TestKeepalive(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -6118,10 +6176,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Introduce a 25ms latency by delaying the first byte.
latency := 25 * time.Millisecond
time.Sleep(latency)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ // Send an initial payload with atleast segment overhead size. The receive
+ // window would not grow for smaller segments.
+ rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
+
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
@@ -6394,10 +6455,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
- gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
- }
+ gotDelayOption := ep.SocketOptions().GetDelayOption()
if gotDelayOption != wantDelayOption {
t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
@@ -7250,9 +7308,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// Set userTimeout to be the duration to be 1 keepalive
// probes. Which means that after the first probe is sent
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e6aa4fc4b..ee55f030c 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetV6Only(v6only)
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
@@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ TransportProtocol: tcp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c78549424..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,8 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 57976d4e3..763d1d654 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,8 +16,8 @@ package udp
import (
"fmt"
+ "sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -30,10 +30,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -77,6 +78,7 @@ func (s EndpointState) String() string {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
@@ -94,21 +96,20 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
- v6only bool
ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
- multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -122,17 +123,6 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- // receiveTOS determines if the incoming IPv4 TOS header field is passed
- // as ancillary data to ControlMessages on Read.
- receiveTOS bool
-
- // receiveTClass determines if the incoming IPv6 TClass header field is
- // passed as ancillary data to ControlMessages on Read.
- receiveTClass bool
-
- // receiveIPPacketInfo determines if the packet info is returned by Read.
- receiveIPPacketInfo bool
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -154,9 +144,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -188,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -210,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState() returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
// UniqueID implements stack.TransportEndpoint.UniqueID.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -235,7 +237,7 @@ func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
+ switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
@@ -258,10 +260,13 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
- e.state = StateClosed
+ e.setEndpointState(StateClosed)
e.mu.Unlock()
@@ -303,24 +308,23 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
HasTimestamp: true,
Timestamp: p.timestamp,
}
- e.mu.RLock()
- receiveTOS := e.receiveTOS
- receiveTClass := e.receiveTClass
- receiveIPPacketInfo := e.receiveIPPacketInfo
- e.mu.RUnlock()
- if receiveTOS {
+ if e.ops.GetReceiveTOS() {
cm.HasTOS = true
cm.TOS = p.tos
}
- if receiveTClass {
+ if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
- if receiveIPPacketInfo {
+ if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
return p.data.ToView(), cm, nil
}
@@ -330,7 +334,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateConnected:
return false, nil
@@ -352,7 +356,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return true, nil
}
@@ -367,9 +371,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
- if isBroadcastOrMulticast(localAddr) {
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
}
@@ -384,9 +388,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -429,7 +433,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
to := opts.To
e.mu.RLock()
- defer e.mu.RUnlock()
+ lockReleased := false
+ defer func() {
+ if lockReleased {
+ return
+ }
+ e.mu.RUnlock()
+ }()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
@@ -448,36 +458,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
- var dstPort uint16
- if to == nil {
- route = &e.route
- dstPort = e.dstPort
- resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
- // Promote lock to exclusive if using a shared route, given that it may
- // need to change in Route.Resolve() call below.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- if err == nil && route.IsResolutionRequired() {
- ch, err = route.Resolve(waker)
- }
-
- e.mu.Unlock()
- e.mu.RLock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- return
- }
- } else {
+ route := e.route
+ dstPort := e.dstPort
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -505,9 +488,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
- resolve = route.Resolve
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
@@ -515,7 +497,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if route.IsResolutionRequired() {
- if ch, err := resolve(nil); err != nil {
+ if ch, err := route.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -541,77 +523,46 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ localPort := e.ID.LocalPort
+ sendTOS := e.sendTOS
+ owner := e.owner
+ noChecksum := e.SocketOptions().GetNoChecksum()
+ lockReleased = true
+ e.mu.RUnlock()
+
+ // Do not hold lock when sending as loopback is synchronous and if the UDP
+ // datagram ends up generating an ICMP response then it can result in a
+ // deadlock where the ICMP response handling ends up acquiring this endpoint's
+ // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
+ // deadlock if another caller is trying to acquire e.mu in exclusive mode w/
+ // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
+ // lock can be eventually acquired.
+ //
+ // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
+ // locking is prohibited.
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = v
- e.mu.Unlock()
-
- case tcpip.NoChecksumOption:
- e.mu.Lock()
- e.noChecksum = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTOSOption:
- e.mu.Lock()
- e.receiveTOS = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrNotSupported
- }
-
- e.mu.Lock()
- e.receiveTClass = v
- e.mu.Unlock()
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.portFlags.MostRecent = v
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.portFlags.LoadBalanced = v
- e.mu.Unlock()
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
+}
- e.v6only = v
- }
- return nil
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
@@ -814,90 +765,10 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
}
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption:
- return false, nil
-
- case tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.NoChecksumOption:
- e.mu.RLock()
- v := e.noChecksum
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTOSOption:
- e.mu.RLock()
- v := e.receiveTOS
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrNotSupported
- }
-
- e.mu.RLock()
- v := e.receiveTClass
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.portFlags.MostRecent
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.portFlags.LoadBalanced
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.RLock()
- v := e.v6only
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -972,11 +843,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
- case *tcpip.LingerOption:
- e.mu.RLock()
- *o = e.linger
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1036,7 +902,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -1048,7 +914,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return nil
}
var (
@@ -1071,7 +937,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = StateBound
+ e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
@@ -1079,14 +945,14 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- e.state = StateInitial
+ e.setEndpointState(StateInitial)
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil
@@ -1104,7 +970,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
var localPort uint16
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateBound, StateConnected:
localPort = e.ID.LocalPort
@@ -1139,7 +1005,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -1147,7 +1013,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -1173,7 +1039,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
- e.state = StateConnected
+ e.setEndpointState(StateConnected)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1195,7 +1061,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != StateBound && e.state != StateConnected {
+ if state := e.EndpointState(); state != StateBound && state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -1246,7 +1112,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1259,7 +1125,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
@@ -1267,7 +1133,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicID := addr.NIC
- if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
// A local unicast address was specified, verify that it's valid.
nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicID == 0 {
@@ -1290,7 +1156,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1322,7 +1188,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
addr := e.ID.LocalAddress
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
addr = e.route.LocalAddress
}
@@ -1338,7 +1204,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1393,7 +1259,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1402,6 +1267,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
+ // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
+ // packets at "Parse" instead of when handling a packet.
+ pkt.Data.CapLength(int(hdr.PayloadLength()))
+
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1435,7 +1304,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
+ },
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
},
}
packet.data = pkt.Data
@@ -1470,25 +1344,20 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
- e.mu.RLock()
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
e.lastErrorMu.Lock()
e.lastError = tcpip.ErrConnectionRefused
e.lastErrorMu.Unlock()
- e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventErr)
return
}
- e.mu.RUnlock()
}
}
// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
@@ -1508,14 +1377,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.Wait.
func (*endpoint) Wait() {}
-func isBroadcastOrMulticast(a tcpip.Address) bool {
- return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 858c99a45..13b72dc88 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != StateBound && e.state != StateConnected {
+ state := e.EndpointState()
+ if state != StateBound && state != StateConnected {
return
}
@@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ if state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 764ad0857..08980c298 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -32,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +56,7 @@ const (
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
+ invalidPort = 8192
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
@@ -295,7 +298,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
})
}
@@ -360,9 +364,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetV6Only(true)
} else if flow.isBroadcast() {
c.ep.SocketOptions().SetBroadcast(true)
}
@@ -451,12 +453,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -972,7 +974,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
// provided.
func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
+ return testWriteAndVerifyInternal(c, flow, true, checkers...)
}
// testWriteWithoutDestination sends a packet of the given test flow from the
@@ -981,10 +983,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker
// checker functions provided.
func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
+ return testWriteAndVerifyInternal(c, flow, false, checkers...)
}
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
@@ -1006,6 +1008,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
c.checkEndpointWriteStats(1, epstats, err)
+ return payload
+}
+
+func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -1150,6 +1158,39 @@ func TestV4WriteOnConnected(t *testing.T) {
testWriteWithoutDestination(c, unicastV4)
}
+func TestWriteOnConnectedInvalidPort(t *testing.T) {
+ protocols := map[string]tcpip.NetworkProtocolNumber{
+ "ipv4": ipv4.ProtocolNumber,
+ "ipv6": ipv6.ProtocolNumber,
+ }
+ for name, pn := range protocols {
+ t.Run(name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(pn)
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ }
+ if got, want := n, int64(len(payload)); got != want {
+ c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
+ }
+
+ if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
+ })
+ }
+}
+
// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
// that is bound to a V4 multicast address.
func TestWriteOnBoundToV4Multicast(t *testing.T) {
@@ -1372,9 +1413,7 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
- t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
- }
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
NIC: 1,
@@ -1389,6 +1428,93 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
+func TestReadRecvOriginalDstAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedOriginalDstAddr tcpip.FullAddress
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%#v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
+
+ testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1412,16 +1538,12 @@ func TestNoChecksum(t *testing.T) {
c.createEndpointForFlow(flow)
// Disable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(true)
// This option is effective on IPv4 only.
testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
// Enable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(false)
testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
})
}
@@ -1591,13 +1713,15 @@ func TestSetTClass(t *testing.T) {
}
func TestReceiveTosTClass(t *testing.T) {
+ const RcvTOSOpt = "ReceiveTosOption"
+ const RcvTClassOpt = "ReceiveTClassOption"
+
testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
+ name string
+ tests []testFlow
}{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
+ {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1606,29 +1730,32 @@ func TestReceiveTosTClass(t *testing.T) {
defer c.cleanup()
c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ var optionGetter func() bool
+ var optionSetter func(bool)
+ switch name {
+ case RcvTOSOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTOS
+ optionSetter = c.ep.SocketOptions().SetReceiveTOS
+ case RcvTClassOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTClass
+ optionSetter = c.ep.SocketOptions().SetReceiveTClass
+ default:
+ t.Fatalf("unkown test variant: %s", name)
}
+
+ // Verify that setting and reading the option works.
+ v := optionGetter()
// Test for expected default value.
if v != false {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
- }
+ optionSetter(want)
+ got := optionGetter()
if got != want {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
}
@@ -1638,10 +1765,10 @@ func TestReceiveTosTClass(t *testing.T) {
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %s", err)
}
- switch option {
- case tcpip.ReceiveTClassOption:
+ switch name {
+ case RcvTClassOpt:
testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
+ case RcvTOSOpt:
testRead(c, flow, checker.ReceiveTOS(testTOS))
default:
t.Fatalf("unknown test variant: %s", name)
@@ -1788,27 +1915,31 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
+ payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+
+ if got, want := len(origDgram), wantDgramLen; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantDgramLen))
+
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -1883,20 +2014,23 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
+ if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" {
+ t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff)
}
})
}
@@ -1955,12 +2089,12 @@ func TestShortHeader(t *testing.T) {
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -2409,3 +2543,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
+
+func TestReceiveShortLength(t *testing.T) {
+ flows := []testFlow{unicastV4, unicastV6}
+ for _, flow := range flows {
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ payload := newPayload()
+ extraBytes := []byte{1, 2, 3, 4}
+ h := flow.header4Tuple(incoming)
+ var buf buffer.View
+ var proto tcpip.NetworkProtocolNumber
+
+ // Build packets with extra bytes not accounted for in the UDP length
+ // field.
+ var udp header.UDP
+ if flow.isV4() {
+ buf = c.buildV4Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv4(buf)
+ ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ proto = ipv4.ProtocolNumber
+ udp = ip.Payload()
+ } else {
+ buf = c.buildV6Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv6(buf)
+ ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
+ proto = ipv6.ProtocolNumber
+ udp = ip.Payload()
+ }
+
+ if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
+ t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
+ }
+
+ c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Try to receive the data.
+ v, _, err := c.ep.Read(nil)
+ if err != nil {
+ t.Fatalf("c.ep.Read(nil): %s", err)
+ }
+
+ // Check the payload is read back without extra bytes.
+ if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
+ t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 70945f234..e41769017 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -54,14 +54,20 @@ func ResolvePath(executable string) string {
}
}
+ // Favor /usr/local/bin, if it exists.
+ localBin := fmt.Sprintf("/usr/local/bin/%s", executable)
+ if _, err := os.Stat(localBin); err == nil {
+ return localBin
+ }
+
// Try to find via the path.
- guess, err := exec.LookPath(executable)
+ guess, _ := exec.LookPath(executable)
if err == nil {
return guess
}
- // Return a default path.
- return fmt.Sprintf("/usr/local/bin/%s", executable)
+ // Return a bare path; this generates a suitable error.
+ return executable
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 64d17f661..2bf0a22ff 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -17,6 +17,7 @@ package dockerutil
import (
"bytes"
"context"
+ "errors"
"fmt"
"io/ioutil"
"net"
@@ -351,6 +352,9 @@ func (c *Container) SandboxPid(ctx context.Context) (int, error) {
return resp.ContainerJSONBase.State.Pid, nil
}
+// ErrNoIP indicates that no IP address is available.
+var ErrNoIP = errors.New("no IP available")
+
// FindIP returns the IP address of the container.
func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
resp, err := c.client.ContainerInspect(ctx, c.id)
@@ -365,7 +369,7 @@ func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
}
if ip == nil {
- return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ return net.IP{}, ErrNoIP
}
return ip, nil
}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
index 4c739c9e9..bf968acec 100644
--- a/pkg/test/dockerutil/exec.go
+++ b/pkg/test/dockerutil/exec.go
@@ -77,11 +77,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc
return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
}
- if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
- hijack.Close()
- return Process{}, fmt.Errorf("exec start failed with err: %v", err)
- }
-
return Process{
container: c,
execid: resp.ID,
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 49ab87c58..fdd416b5e 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -36,7 +36,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync/atomic"
"syscall"
"testing"
"time"
@@ -49,7 +48,10 @@ import (
)
var (
- checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
+ totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
+ isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet")
)
// IsCheckpointSupported returns the relevant command line flag.
@@ -57,6 +59,11 @@ func IsCheckpointSupported() bool {
return *checkpoint
}
+// IsRunningWithHostNet returns the relevant command line flag.
+func IsRunningWithHostNet() bool {
+ return *isRunningWithHostNet
+}
+
// ImageByName mangles the image name used locally. This depends on the image
// build infrastructure in images/ and tools/vm.
func ImageByName(name string) string {
@@ -249,14 +256,25 @@ func writeSpec(dir string, spec *specs.Spec) error {
// idRandomSrc is a pseudo random generator used to in RandomID.
var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
+// idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used
+// concurrently in differnt goroutines.
+var idRandomSrcMtx sync.Mutex
+
// RandomID returns 20 random bytes following the given prefix.
func RandomID(prefix string) string {
// Read 20 random bytes.
b := make([]byte, 20)
+ // Rand.Read is not safe for concurrent use. Packetimpact tests can be run in
+ // parallel now, so we have to protect the Read with a mutex. Otherwise we'll
+ // run into name conflicts.
+ // https://golang.org/pkg/math/rand/#Rand.Read
+ idRandomSrcMtx.Lock()
// "[Read] always returns len(p) and a nil error." --godoc
if _, err := idRandomSrc.Read(b); err != nil {
+ idRandomSrcMtx.Unlock()
panic("rand.Read failed: " + err.Error())
}
+ idRandomSrcMtx.Unlock()
if prefix != "" {
prefix = prefix + "-"
}
@@ -417,33 +435,35 @@ func StartReaper() func() {
// WaitUntilRead reads from the given reader until the wanted string is found
// or until timeout.
-func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error {
+func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error {
sc := bufio.NewScanner(r)
- if split != nil {
- sc.Split(split)
- }
// done must be accessed atomically. A value greater than 0 indicates
// that the read loop can exit.
- var done uint32
- doneCh := make(chan struct{})
+ doneCh := make(chan bool)
+ defer close(doneCh)
go func() {
for sc.Scan() {
t := sc.Text()
if strings.Contains(t, want) {
- atomic.StoreUint32(&done, 1)
- close(doneCh)
- break
+ doneCh <- true
+ return
}
- if atomic.LoadUint32(&done) > 0 {
- break
+ select {
+ case <-doneCh:
+ return
+ default:
}
}
+ doneCh <- false
}()
+
select {
case <-time.After(timeout):
- atomic.StoreUint32(&done, 1)
return fmt.Errorf("timeout waiting to read %q", want)
- case <-doneCh:
+ case res := <-doneCh:
+ if !res {
+ return fmt.Errorf("reader closed while waiting to read %q", want)
+ }
return nil
}
}
@@ -509,7 +529,8 @@ func TouchShardStatusFile() error {
}
// TestIndicesForShard returns indices for this test shard based on the
-// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as
+// the passed partition flags.
//
// If either of the env vars are not present, then the function will return all
// tests. If there are more shards than there are tests, then the returned list
@@ -534,6 +555,11 @@ func TestIndicesForShard(numTests int) ([]int, error) {
}
}
+ // Combine with the partitions.
+ partitionSize := shardTotal
+ shardTotal = (*totalPartitions) * shardTotal
+ shardIndex = partitionSize*(*partition-1) + shardIndex
+
// Calculate!
var indices []int
numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 9b1e7a085..79db8895b 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) {
return n, err
}
-// Writer implements io.Writer.Write.
+// Write implements io.Writer.Write.
func (rw *IOReadWriter) Write(src []byte) (int, error) {
n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts)
end, ok := rw.Addr.AddLength(uint64(n))
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 08519d986..83d4f893a 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -119,7 +119,10 @@ type EntryCallback interface {
// The callback is supposed to perform minimal work, and cannot call
// any method on the queue itself because it will be locked while the
// callback is running.
- Callback(e *Entry)
+ //
+ // The mask indicates the events that occurred and that the entry is
+ // interested in.
+ Callback(e *Entry, mask EventMask)
}
// Entry represents a waiter that can be add to the a wait queue. It can
@@ -140,7 +143,7 @@ type channelCallback struct {
}
// Callback implements EntryCallback.Callback.
-func (c *channelCallback) Callback(*Entry) {
+func (c *channelCallback) Callback(*Entry, EventMask) {
select {
case c.ch <- struct{}{}:
default:
@@ -193,8 +196,8 @@ func (q *Queue) EventUnregister(e *Entry) {
func (q *Queue) Notify(mask EventMask) {
q.mu.RLock()
for e := q.list.Front(); e != nil; e = e.Next() {
- if mask&e.mask != 0 {
- e.Callback.Callback(e)
+ if m := mask & e.mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
q.mu.RUnlock()
diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go
index c1b94a4f3..6928f28b4 100644
--- a/pkg/waiter/waiter_test.go
+++ b/pkg/waiter/waiter_test.go
@@ -20,12 +20,12 @@ import (
)
type callbackStub struct {
- f func(e *Entry)
+ f func(e *Entry, m EventMask)
}
// Callback implements EntryCallback.Callback.
-func (c *callbackStub) Callback(e *Entry) {
- c.f(e)
+func (c *callbackStub) Callback(e *Entry, m EventMask) {
+ c.f(e, m)
}
func TestEmptyQueue(t *testing.T) {
@@ -36,7 +36,7 @@ func TestEmptyQueue(t *testing.T) {
// Register then unregister a waiter, then notify the queue.
cnt := 0
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn)
q.EventUnregister(&e)
q.Notify(EventIn)
@@ -49,7 +49,7 @@ func TestMask(t *testing.T) {
// Register a waiter.
var q Queue
var cnt int
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn|EventErr)
// Notify with an overlapping mask.
@@ -101,11 +101,14 @@ func TestConcurrentRegistration(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
cnt++
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
// Wait for notification, then register.
@@ -158,11 +161,14 @@ func TestConcurrentNotification(t *testing.T) {
// Register waiters.
for i := 0; i < waiterCount; i++ {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
atomic.AddInt32(&cnt, 1)
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
q.EventRegister(&e, EventIn|EventErr)