Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/BUILD | 4
-rw-r--r--  pkg/abi/linux/elf.go | 50
-rw-r--r--  pkg/abi/linux/epoll.go | 6
-rw-r--r--  pkg/abi/linux/file.go | 5
-rw-r--r--  pkg/abi/linux/fs.go | 3
-rw-r--r--  pkg/abi/linux/netdevice.go | 4
-rw-r--r--  pkg/abi/linux/netfilter.go | 47
-rw-r--r--  pkg/abi/linux/netfilter_ipv6.go | 14
-rw-r--r--  pkg/abi/linux/netfilter_test.go | 5
-rw-r--r--  pkg/abi/linux/netlink.go | 6
-rw-r--r--  pkg/abi/linux/netlink_route.go | 6
-rw-r--r--  pkg/abi/linux/ptrace_amd64.go | 5
-rw-r--r--  pkg/abi/linux/ptrace_arm64.go | 5
-rw-r--r--  pkg/abi/linux/socket.go | 16
-rw-r--r--  pkg/bits/bits.go | 10
-rw-r--r--  pkg/bpf/BUILD | 3
-rw-r--r--  pkg/bpf/interpreter_test.go | 20
-rw-r--r--  pkg/compressio/BUILD | 5
-rw-r--r--  pkg/compressio/compressio.go | 41
-rw-r--r--  pkg/coverage/coverage.go | 99
-rw-r--r--  pkg/gohacks/BUILD | 10
-rw-r--r--  pkg/gohacks/gohacks_test.go | 97
-rw-r--r--  pkg/gohacks/gohacks_unsafe.go | 14
-rw-r--r--  pkg/marshal/BUILD | 1
-rw-r--r--  pkg/marshal/marshal.go | 3
-rw-r--r--  pkg/marshal/primitive/primitive.go | 75
-rw-r--r--  pkg/marshal/util.go (renamed from pkg/tcpip/transport/tcp/rack_state.go) | 20
-rw-r--r--  pkg/merkletree/merkletree.go | 5
-rw-r--r--  pkg/metric/metric.go | 193
-rw-r--r--  pkg/metric/metric.proto | 11
-rw-r--r--  pkg/metric/metric_test.go | 92
-rw-r--r--  pkg/p9/client_file.go | 16
-rw-r--r--  pkg/p9/file.go | 60
-rw-r--r--  pkg/p9/handlers.go | 28
-rw-r--r--  pkg/p9/messages.go | 84
-rw-r--r--  pkg/p9/p9.go | 28
-rw-r--r--  pkg/p9/version.go | 8
-rw-r--r--  pkg/refs/refcounter.go | 10
-rw-r--r--  pkg/refsvfs2/BUILD | 2
-rw-r--r--  pkg/refsvfs2/refs_map.go | 25
-rw-r--r--  pkg/ring0/kernel_amd64.go | 19
-rw-r--r--  pkg/ring0/kernel_arm64.go | 8
-rw-r--r--  pkg/ring0/lib_amd64.go | 6
-rw-r--r--  pkg/ring0/lib_amd64.s | 12
-rw-r--r--  pkg/ring0/lib_arm64.go | 3
-rw-r--r--  pkg/ring0/lib_arm64.s | 8
-rw-r--r--  pkg/ring0/pagetables/BUILD | 5
-rw-r--r--  pkg/ring0/pagetables/pagetables.go | 9
-rw-r--r--  pkg/ring0/pagetables/pagetables_arm64_test.go | 38
-rw-r--r--  pkg/safecopy/atomic_amd64.s | 48
-rw-r--r--  pkg/safecopy/atomic_arm64.s | 36
-rw-r--r--  pkg/safecopy/memclr_amd64.s | 6
-rw-r--r--  pkg/safecopy/memclr_arm64.s | 6
-rw-r--r--  pkg/safecopy/memcpy_amd64.s | 10
-rw-r--r--  pkg/safecopy/memcpy_arm64.s | 10
-rw-r--r--  pkg/safecopy/safecopy.go | 22
-rw-r--r--  pkg/safecopy/safecopy_test.go | 62
-rw-r--r--  pkg/safecopy/safecopy_unsafe.go | 12
-rw-r--r--  pkg/safecopy/sighandler_amd64.s | 6
-rw-r--r--  pkg/safecopy/sighandler_arm64.s | 6
-rw-r--r--  pkg/safemem/BUILD | 1
-rw-r--r--  pkg/safemem/block_unsafe.go | 19
-rw-r--r--  pkg/sentry/arch/fpu/fpu_amd64.go | 5
-rw-r--r--  pkg/sentry/devices/memdev/zero.go | 1
-rw-r--r--  pkg/sentry/fs/host/socket.go | 12
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/BUILD | 48
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/base.go | 261
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/cgroupfs.go | 425
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/cpu.go | 70
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/cpuacct.go | 114
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/cpuset.go | 39
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/job.go | 64
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/memory.go | 74
-rw-r--r--  pkg/sentry/fsimpl/ext/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/gofer/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/gofer/filesystem.go | 360
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 250
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer_test.go | 6
-rw-r--r--  pkg/sentry/fsimpl/gofer/p9file.go | 7
-rw-r--r--  pkg/sentry/fsimpl/gofer/regular_file.go | 10
-rw-r--r--  pkg/sentry/fsimpl/gofer/revalidate.go | 386
-rw-r--r--  pkg/sentry/fsimpl/host/host.go | 3
-rw-r--r--  pkg/sentry/fsimpl/host/save_restore.go | 7
-rw-r--r--  pkg/sentry/fsimpl/host/socket.go | 19
-rw-r--r--  pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 10
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go | 12
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs.go | 16
-rw-r--r--  pkg/sentry/fsimpl/proc/filesystem.go | 6
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go | 23
-rw-r--r--  pkg/sentry/fsimpl/proc/task_fds.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/task_files.go | 29
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks.go | 20
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_files.go | 55
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/sys/sys.go | 8
-rw-r--r--  pkg/sentry/fsimpl/testutil/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/testutil/kernel.go | 3
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/regular_file.go | 1
-rw-r--r--  pkg/sentry/fsimpl/verity/BUILD | 2
-rw-r--r--  pkg/sentry/fsimpl/verity/filesystem.go | 144
-rw-r--r--  pkg/sentry/fsimpl/verity/verity.go | 339
-rw-r--r--  pkg/sentry/fsimpl/verity/verity_test.go | 3
-rw-r--r--  pkg/sentry/kernel/BUILD | 3
-rw-r--r--  pkg/sentry/kernel/cgroup.go | 281
-rw-r--r--  pkg/sentry/kernel/kernel.go | 52
-rw-r--r--  pkg/sentry/kernel/task.go | 6
-rw-r--r--  pkg/sentry/kernel/task_cgroup.go | 138
-rw-r--r--  pkg/sentry/kernel/task_exit.go | 4
-rw-r--r--  pkg/sentry/kernel/task_start.go | 5
-rw-r--r--  pkg/sentry/kernel/task_syscall.go | 4
-rw-r--r--  pkg/sentry/kernel/threads.go | 9
-rw-r--r--  pkg/sentry/loader/BUILD | 3
-rw-r--r--  pkg/sentry/loader/elf.go | 14
-rw-r--r--  pkg/sentry/memmap/memmap.go | 5
-rw-r--r--  pkg/sentry/platform/kvm/BUILD | 1
-rw-r--r--  pkg/sentry/platform/kvm/bluepill.go | 13
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.s | 12
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.s | 12
-rw-r--r--  pkg/sentry/platform/kvm/kvm_amd64_test.go | 37
-rw-r--r--  pkg/sentry/platform/kvm/kvm_amd64_test.s | 21
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const.go | 1
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 23
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go | 31
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64.go | 38
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 67
-rw-r--r--  pkg/sentry/platform/ptrace/stub_amd64.s | 6
-rw-r--r--  pkg/sentry/platform/ptrace/stub_arm64.s | 6
-rw-r--r--  pkg/sentry/platform/ptrace/stub_unsafe.go | 9
-rw-r--r--  pkg/sentry/socket/BUILD | 1
-rw-r--r--  pkg/sentry/socket/control/BUILD | 4
-rw-r--r--  pkg/sentry/socket/control/control.go | 66
-rw-r--r--  pkg/sentry/socket/hostinet/BUILD | 1
-rw-r--r--  pkg/sentry/socket/hostinet/socket.go | 25
-rw-r--r--  pkg/sentry/socket/hostinet/stack.go | 29
-rw-r--r--  pkg/sentry/socket/netfilter/BUILD | 4
-rw-r--r--  pkg/sentry/socket/netfilter/extensions.go | 13
-rw-r--r--  pkg/sentry/socket/netfilter/ipv4.go | 7
-rw-r--r--  pkg/sentry/socket/netfilter/ipv6.go | 7
-rw-r--r--  pkg/sentry/socket/netfilter/netfilter.go | 17
-rw-r--r--  pkg/sentry/socket/netfilter/owner_matcher.go | 9
-rw-r--r--  pkg/sentry/socket/netfilter/targets.go | 214
-rw-r--r--  pkg/sentry/socket/netfilter/tcp_matcher.go | 8
-rw-r--r--  pkg/sentry/socket/netfilter/udp_matcher.go | 8
-rw-r--r--  pkg/sentry/socket/netlink/BUILD | 4
-rw-r--r--  pkg/sentry/socket/netlink/message.go | 40
-rw-r--r--  pkg/sentry/socket/netlink/message_test.go | 18
-rw-r--r--  pkg/sentry/socket/netlink/route/BUILD | 1
-rw-r--r--  pkg/sentry/socket/netlink/route/protocol.go | 30
-rw-r--r--  pkg/sentry/socket/netlink/socket.go | 21
-rw-r--r--  pkg/sentry/socket/netstack/BUILD | 1
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go | 58
-rw-r--r--  pkg/sentry/socket/socket.go | 15
-rw-r--r--  pkg/sentry/socket/unix/transport/connectioned.go | 14
-rw-r--r--  pkg/sentry/socket/unix/transport/connectioned_state.go | 2
-rw-r--r--  pkg/sentry/socket/unix/transport/connectionless.go | 3
-rw-r--r--  pkg/sentry/socket/unix/transport/connectionless_state.go | 2
-rw-r--r--  pkg/sentry/socket/unix/transport/unix.go | 31
-rw-r--r--  pkg/sentry/strace/BUILD | 1
-rw-r--r--  pkg/sentry/strace/linux64_amd64.go | 1
-rw-r--r--  pkg/sentry/strace/linux64_arm64.go | 1
-rw-r--r--  pkg/sentry/strace/socket.go | 32
-rw-r--r--  pkg/sentry/syscalls/epoll.go | 8
-rw-r--r--  pkg/sentry/syscalls/linux/error.go | 15
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 2
-rw-r--r--  pkg/sentry/syscalls/linux/sys_epoll.go | 56
-rw-r--r--  pkg/sentry/syscalls/linux/sys_socket.go | 31
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/epoll.go | 52
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/socket.go | 31
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/vfs2.go | 2
-rw-r--r--  pkg/sentry/time/BUILD | 1
-rw-r--r--  pkg/sentry/time/calibrated_clock.go | 7
-rw-r--r--  pkg/sentry/vfs/file_description.go | 14
-rw-r--r--  pkg/sentry/vfs/file_description_impl_util.go | 3
-rw-r--r--  pkg/sentry/vfs/mount.go | 20
-rw-r--r--  pkg/sentry/vfs/opath.go | 4
-rw-r--r--  pkg/sentry/vfs/resolving_path.go | 84
-rw-r--r--  pkg/sentry/vfs/vfs.go | 83
-rw-r--r--  pkg/shim/utils/volumes.go | 46
-rw-r--r--  pkg/shim/utils/volumes_test.go | 160
-rw-r--r--  pkg/state/statefile/BUILD | 1
-rw-r--r--  pkg/state/statefile/statefile.go | 21
-rw-r--r--  pkg/sync/BUILD | 1
-rw-r--r--  pkg/sync/generic_seqatomic_unsafe.go | 3
-rw-r--r--  pkg/sync/runtime_unsafe.go | 14
-rw-r--r--  pkg/sync/seqatomictest/BUILD | 1
-rw-r--r--  pkg/tcpip/BUILD | 34
-rw-r--r--  pkg/tcpip/checker/checker.go | 20
-rw-r--r--  pkg/tcpip/hash/jenkins/jenkins.go | 20
-rw-r--r--  pkg/tcpip/header/BUILD | 2
-rw-r--r--  pkg/tcpip/header/eth_test.go | 3
-rw-r--r--  pkg/tcpip/header/igmp_test.go | 6
-rw-r--r--  pkg/tcpip/header/ipv4.go | 58
-rw-r--r--  pkg/tcpip/header/ipv4_test.go | 75
-rw-r--r--  pkg/tcpip/header/ipv6.go | 87
-rw-r--r--  pkg/tcpip/header/ipv6_test.go | 104
-rw-r--r--  pkg/tcpip/header/ndp_test.go | 11
-rw-r--r--  pkg/tcpip/header/tcp.go | 53
-rw-r--r--  pkg/tcpip/header/udp.go | 27
-rw-r--r--  pkg/tcpip/link/channel/channel.go | 18
-rw-r--r--  pkg/tcpip/link/ethernet/ethernet.go | 8
-rw-r--r--  pkg/tcpip/link/fdbased/BUILD | 1
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go | 77
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint_test.go | 15
-rw-r--r--  pkg/tcpip/link/fdbased/packet_dispatchers.go | 4
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go | 4
-rw-r--r--  pkg/tcpip/link/muxed/injectable.go | 8
-rw-r--r--  pkg/tcpip/link/muxed/injectable_test.go | 4
-rw-r--r--  pkg/tcpip/link/nested/nested.go | 16
-rw-r--r--  pkg/tcpip/link/packetsocket/endpoint.go | 8
-rw-r--r--  pkg/tcpip/link/pipe/pipe.go | 4
-rw-r--r--  pkg/tcpip/link/qdisc/fifo/endpoint.go | 33
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go | 4
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_test.go | 24
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 24
-rw-r--r--  pkg/tcpip/link/waitable/waitable.go | 8
-rw-r--r--  pkg/tcpip/link/waitable/waitable_test.go | 10
-rw-r--r--  pkg/tcpip/network/BUILD | 1
-rw-r--r--  pkg/tcpip/network/arp/BUILD | 1
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 8
-rw-r--r--  pkg/tcpip/network/arp/arp_test.go | 18
-rw-r--r--  pkg/tcpip/network/internal/ip/BUILD | 1
-rw-r--r--  pkg/tcpip/network/internal/ip/errors.go | 77
-rw-r--r--  pkg/tcpip/network/internal/ip/generic_multicast_protocol.go | 58
-rw-r--r--  pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go | 27
-rw-r--r--  pkg/tcpip/network/internal/ip/stats.go | 84
-rw-r--r--  pkg/tcpip/network/internal/testutil/testutil.go | 6
-rw-r--r--  pkg/tcpip/network/ip_test.go | 114
-rw-r--r--  pkg/tcpip/network/ipv4/BUILD | 1
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 38
-rw-r--r--  pkg/tcpip/network/ipv4/igmp.go | 14
-rw-r--r--  pkg/tcpip/network/ipv4/igmp_test.go | 20
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 215
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go | 312
-rw-r--r--  pkg/tcpip/network/ipv6/BUILD | 1
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 84
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 14
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 198
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6_test.go | 277
-rw-r--r--  pkg/tcpip/network/ipv6/mld.go | 24
-rw-r--r--  pkg/tcpip/network/ipv6/mld_test.go | 157
-rw-r--r--  pkg/tcpip/network/ipv6/ndp.go | 142
-rw-r--r--  pkg/tcpip/network/ipv6/stats.go | 4
-rw-r--r--  pkg/tcpip/network/multicast_group_test.go | 30
-rw-r--r--  pkg/tcpip/ports/BUILD | 1
-rw-r--r--  pkg/tcpip/ports/ports.go | 26
-rw-r--r--  pkg/tcpip/ports/ports_test.go | 36
-rw-r--r--  pkg/tcpip/socketops.go | 58
-rw-r--r--  pkg/tcpip/stack/BUILD | 4
-rw-r--r--  pkg/tcpip/stack/conntrack.go | 236
-rw-r--r--  pkg/tcpip/stack/forwarding_test.go | 23
-rw-r--r--  pkg/tcpip/stack/hook_string.go | 41
-rw-r--r--  pkg/tcpip/stack/iptables.go | 31
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 92
-rw-r--r--  pkg/tcpip/stack/iptables_types.go | 2
-rw-r--r--  pkg/tcpip/stack/ndp_test.go | 1156
-rw-r--r--  pkg/tcpip/stack/neighbor_cache_test.go | 2
-rw-r--r--  pkg/tcpip/stack/neighbor_entry_test.go | 8
-rw-r--r--  pkg/tcpip/stack/nic.go | 34
-rw-r--r--  pkg/tcpip/stack/nic_test.go | 4
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 20
-rw-r--r--  pkg/tcpip/stack/packet_buffer_test.go | 140
-rw-r--r--  pkg/tcpip/stack/pending_packets.go | 8
-rw-r--r--  pkg/tcpip/stack/registration.go | 38
-rw-r--r--  pkg/tcpip/stack/route.go | 20
-rw-r--r--  pkg/tcpip/stack/stack.go | 313
-rw-r--r--  pkg/tcpip/stack/stack_global_state.go | 72
-rw-r--r--  pkg/tcpip/stack/stack_options.go | 4
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 89
-rw-r--r--  pkg/tcpip/stack/tcp.go | 451
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go | 17
-rw-r--r--  pkg/tcpip/stack/transport_test.go | 6
-rw-r--r--  pkg/tcpip/stdclock.go | 130
-rw-r--r--  pkg/tcpip/stdclock_state.go (renamed from pkg/tcpip/transport/tcp/cubic_state.go) | 21
-rw-r--r--  pkg/tcpip/tcpip.go | 89
-rw-r--r--  pkg/tcpip/tests/integration/BUILD | 6
-rw-r--r--  pkg/tcpip/tests/integration/forward_test.go | 194
-rw-r--r--  pkg/tcpip/tests/integration/iptables_test.go | 4
-rw-r--r--  pkg/tcpip/tests/integration/link_resolution_test.go | 6
-rw-r--r--  pkg/tcpip/tests/integration/loopback_test.go | 21
-rw-r--r--  pkg/tcpip/tests/integration/multicast_broadcast_test.go | 16
-rw-r--r--  pkg/tcpip/tests/integration/route_test.go | 5
-rw-r--r--  pkg/tcpip/tests/utils/utils.go | 8
-rw-r--r--  pkg/tcpip/testutil/BUILD | 18
-rw-r--r--  pkg/tcpip/testutil/testutil.go | 43
-rw-r--r--  pkg/tcpip/testutil/testutil_test.go | 103
-rw-r--r--  pkg/tcpip/time_unsafe.go | 75
-rw-r--r--  pkg/tcpip/timer_test.go | 32
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 65
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint_state.go | 33
-rw-r--r--  pkg/tcpip/transport/packet/endpoint.go | 74
-rw-r--r--  pkg/tcpip/transport/packet/endpoint_state.go | 25
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 78
-rw-r--r--  pkg/tcpip/transport/raw/endpoint_state.go | 33
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD | 3
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go | 291
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 182
-rw-r--r--  pkg/tcpip/transport/tcp/cubic.go | 119
-rw-r--r--  pkg/tcpip/transport/tcp/dispatcher.go | 2
-rw-r--r--  pkg/tcpip/transport/tcp/dual_stack_test.go | 33
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 902
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 82
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 80
-rw-r--r--  pkg/tcpip/transport/tcp/rack.go | 129
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go | 173
-rw-r--r--  pkg/tcpip/transport/tcp/reno.go | 30
-rw-r--r--  pkg/tcpip/transport/tcp/reno_recovery.go | 14
-rw-r--r--  pkg/tcpip/transport/tcp/sack_recovery.go | 18
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 14
-rw-r--r--  pkg/tcpip/transport/tcp/segment_queue.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 442
-rw-r--r--  pkg/tcpip/transport/tcp/snd_state.go | 20
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_rack_test.go | 10
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_sack_test.go | 14
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 421
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 8
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 12
-rw-r--r--  pkg/tcpip/transport/udp/BUILD | 1
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 104
-rw-r--r--  pkg/tcpip/transport/udp/endpoint_state.go | 34
-rw-r--r--  pkg/tcpip/transport/udp/udp_test.go | 3
-rw-r--r--  pkg/test/dockerutil/BUILD | 2
-rw-r--r--  pkg/test/dockerutil/container.go | 9
-rw-r--r--  pkg/test/dockerutil/profile.go | 11
323 files changed, 11180 insertions, 5169 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index ecaeb11ac..064a54547 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -76,7 +76,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/abi",
- "//pkg/binary",
"//pkg/bits",
"//pkg/marshal",
"//pkg/marshal/primitive",
@@ -88,7 +87,4 @@ go_test(
size = "small",
srcs = ["netfilter_test.go"],
library = ":linux",
- deps = [
- "//pkg/binary",
- ],
)
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index 7c9a02f20..c5713541f 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -106,3 +106,53 @@ const (
// NT_ARM_TLS is for ARM TLS register.
NT_ARM_TLS = 0x401
)
+
+// ElfHeader64 is the ELF64 file header.
+//
+// +marshal
+type ElfHeader64 struct {
+ Ident [16]byte // File identification.
+ Type uint16 // File type.
+ Machine uint16 // Machine architecture.
+ Version uint32 // ELF format version.
+ Entry uint64 // Entry point.
+ Phoff uint64 // Program header file offset.
+ Shoff uint64 // Section header file offset.
+ Flags uint32 // Architecture-specific flags.
+ Ehsize uint16 // Size of ELF header in bytes.
+ Phentsize uint16 // Size of program header entry.
+ Phnum uint16 // Number of program header entries.
+ Shentsize uint16 // Size of section header entry.
+ Shnum uint16 // Number of section header entries.
+ Shstrndx uint16 // Section name strings section.
+}
+
+// ElfSection64 is the ELF64 Section header.
+//
+// +marshal
+type ElfSection64 struct {
+ Name uint32 // Section name (index into the section header string table).
+ Type uint32 // Section type.
+ Flags uint64 // Section flags.
+ Addr uint64 // Address in memory image.
+ Off uint64 // Offset in file.
+ Size uint64 // Size in bytes.
+ Link uint32 // Index of a related section.
+ Info uint32 // Depends on section type.
+ Addralign uint64 // Alignment in bytes.
+ Entsize uint64 // Size of each entry in section.
+}
+
+// ElfProg64 is the ELF64 Program header.
+//
+// +marshal
+type ElfProg64 struct {
+ Type uint32 // Entry type.
+ Flags uint32 // Access permission flags.
+ Off uint64 // File offset of contents.
+ Vaddr uint64 // Virtual address in memory image.
+ Paddr uint64 // Physical address (not used).
+ Filesz uint64 // Size of contents in file.
+ Memsz uint64 // Size of contents in memory.
+ Align uint64 // Alignment in memory and file.
+}
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 1121a1a92..67706f5aa 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,10 +14,6 @@
package linux
-import (
- "gvisor.dev/gvisor/pkg/binary"
-)
-
// Event masks.
const (
EPOLLIN = 0x1
@@ -59,4 +55,4 @@ const (
)
// SizeOfEpollEvent is the size of EpollEvent struct.
-var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
+var SizeOfEpollEvent = (*EpollEvent)(nil).SizeBytes()
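
For reference, a minimal sketch of the nil-receiver sizing idiom adopted above. It assumes only that the generated SizeBytes never dereferences its receiver; demoEvent is a hypothetical stand-in for a +marshal type, not gVisor code:

package main

import "fmt"

type demoEvent struct {
	Events uint32
	Data   [2]int32
}

// SizeBytes reports the serialized size without reading any fields, so
// calling it through a typed nil pointer is safe and allocation-free.
func (*demoEvent) SizeBytes() int {
	return 4 + 4*2
}

func main() {
	fmt.Println((*demoEvent)(nil).SizeBytes()) // 12
}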
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index e11ca2d62..1e23850a9 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -19,7 +19,6 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi"
- "gvisor.dev/gvisor/pkg/binary"
)
// Constants for open(2).
@@ -201,7 +200,7 @@ const (
)
// SizeOfStat is the size of a Stat struct.
-var SizeOfStat = binary.Size(Stat{})
+var SizeOfStat = (*Stat)(nil).SizeBytes()
// Flags for statx.
const (
@@ -268,7 +267,7 @@ type Statx struct {
}
// SizeOfStatx is the size of a Statx struct.
-var SizeOfStatx = binary.Size(Statx{})
+var SizeOfStatx = (*Statx)(nil).SizeBytes()
// FileMode represents a mode_t.
type FileMode uint16
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 0d921ed6f..cad24fcc7 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -19,8 +19,10 @@ package linux
// See linux/magic.h.
const (
ANON_INODE_FS_MAGIC = 0x09041934
+ CGROUP_SUPER_MAGIC = 0x27e0eb
DEVPTS_SUPER_MAGIC = 0x00001cd1
EXT_SUPER_MAGIC = 0xef53
+ FUSE_SUPER_MAGIC = 0x65735546
OVERLAYFS_SUPER_MAGIC = 0x794c7630
PIPEFS_MAGIC = 0x50495045
PROC_SUPER_MAGIC = 0x9fa0
@@ -29,7 +31,6 @@ const (
SYSFS_MAGIC = 0x62656572
TMPFS_MAGIC = 0x01021994
V9FS_MAGIC = 0x01021997
- FUSE_SUPER_MAGIC = 0x65735546
)
// Filesystem path limits, from uapi/linux/limits.h.
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
index 0faf015c7..51a39704b 100644
--- a/pkg/abi/linux/netdevice.go
+++ b/pkg/abi/linux/netdevice.go
@@ -14,8 +14,6 @@
package linux
-import "gvisor.dev/gvisor/pkg/binary"
-
const (
// IFNAMSIZ is the size of the name field for IFReq.
IFNAMSIZ = 16
@@ -66,7 +64,7 @@ func (ifr *IFReq) SetName(name string) {
}
// SizeOfIFReq is the binary size of an IFReq struct (40 bytes).
-var SizeOfIFReq = binary.Size(IFReq{})
+var SizeOfIFReq = (*IFReq)(nil).SizeBytes()
// IFMap contains interface hardware parameters.
type IFMap struct {
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 378f1baf3..3fd05483a 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -145,13 +145,13 @@ func (ke *KernelIPTEntry) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
- ke.Entry.MarshalBytes(dst)
+ ke.Entry.MarshalUnsafe(dst)
ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
}
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
- ke.Entry.UnmarshalBytes(src)
+ ke.Entry.UnmarshalUnsafe(src)
ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
@@ -245,6 +245,8 @@ const SizeOfXTCounters = 16
// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
// exposing different data to the user and kernel, but this struct holds only
// the user data.
+//
+// +marshal
type XTEntryMatch struct {
MatchSize uint16
Name ExtensionName
@@ -284,6 +286,8 @@ const SizeOfXTGetRevision = 30
// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
// exposing different data to the user and kernel, but this struct holds only
// the user data.
+//
+// +marshal
type XTEntryTarget struct {
TargetSize uint16
Name ExtensionName
@@ -306,6 +310,8 @@ type KernelXTEntryTarget struct {
// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTStandardTarget struct {
Target XTEntryTarget
// A positive verdict indicates a jump, and is the offset from the
@@ -322,6 +328,8 @@ const SizeOfXTStandardTarget = 40
// beginning of user-defined chains by putting the name of the chain in
// ErrorName. It corresponds to struct xt_error_target in
// include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTErrorTarget struct {
Target XTEntryTarget
Name ErrorName
@@ -349,6 +357,8 @@ const (
// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
// network byte order.
+//
+// +marshal
type NfNATIPV4Range struct {
Flags uint32
MinIP [4]byte
@@ -359,6 +369,8 @@ type NfNATIPV4Range struct {
// NfNATIPV4MultiRangeCompat corresponds to struct
// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+//
+// +marshal
type NfNATIPV4MultiRangeCompat struct {
RangeSize uint32
RangeIPV4 NfNATIPV4Range
@@ -366,6 +378,8 @@ type NfNATIPV4MultiRangeCompat struct {
// XTRedirectTarget triggers a redirect when reached.
// Adding 4 bytes of padding to make the struct 8 byte aligned.
+//
+// +marshal
type XTRedirectTarget struct {
Target XTEntryTarget
NfRange NfNATIPV4MultiRangeCompat
@@ -375,6 +389,19 @@ type XTRedirectTarget struct {
// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
const SizeOfXTRedirectTarget = 56
+// XTSNATTarget triggers Source NAT when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+//
+// +marshal
+type XTSNATTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTSNATTarget is the size of an XTSNATTarget.
+const SizeOfXTSNATTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
//
@@ -429,7 +456,7 @@ func (ke *KernelIPTGetEntries) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
- ke.IPTGetEntries.MarshalBytes(dst)
+ ke.IPTGetEntries.MarshalUnsafe(dst)
marshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
@@ -439,7 +466,7 @@ func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
- ke.IPTGetEntries.UnmarshalBytes(src)
+ ke.IPTGetEntries.UnmarshalUnsafe(src)
unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
@@ -452,6 +479,8 @@ var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTReplace struct {
Name TableName
ValidHooks uint32
@@ -491,6 +520,8 @@ func (tn TableName) String() string {
// ErrorName holds the name of a netfilter error. These can also hold
// user-defined chains.
+//
+// +marshal
type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
// String implements fmt.Stringer.
@@ -509,6 +540,8 @@ func goString(cstring []byte) string {
// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
// in include/uapi/linux/netfilter/xt_tcpudp.h.
+//
+// +marshal
type XTTCP struct {
// SourcePortStart specifies the inclusive start of the range of source
// ports to which the matcher applies.
@@ -562,6 +595,8 @@ const (
// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
// in include/uapi/linux/netfilter/xt_tcpudp.h.
+//
+// +marshal
type XTUDP struct {
// SourcePortStart is the inclusive start of the range of source ports
// to which the matcher applies.
@@ -602,6 +637,8 @@ const (
// IPTOwnerInfo holds data for matching packets with owner. It corresponds
// to struct ipt_owner_info in libxt_owner.c of iptables binary.
+//
+// +marshal
type IPTOwnerInfo struct {
// UID is user id which created the packet.
UID uint32
@@ -623,7 +660,7 @@ type IPTOwnerInfo struct {
Match uint8
// Invert flips the meaning of Match field.
- Invert uint8
+ Invert uint8 `marshal:"unaligned"`
}
// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo.
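
The KernelIPTEntry methods above marshal the fixed-size header with MarshalUnsafe and only the variable-length tail by hand. A hedged sketch of that split, with hypothetical types standing in for the generated code:

package sketch

import "encoding/binary"

// header stands in for a fixed-size +marshal struct.
type header struct{ Len uint32 }

func (*header) SizeBytes() int { return 4 }

// MarshalUnsafe may copy the struct directly because its layout is fixed.
func (h *header) MarshalUnsafe(dst []byte) {
	binary.LittleEndian.PutUint32(dst, h.Len)
}

// dynEntry mirrors the "+marshal dynamic" pattern: fixed header, dynamic tail.
type dynEntry struct {
	hdr   header
	elems []byte
}

func (e *dynEntry) SizeBytes() int { return e.hdr.SizeBytes() + len(e.elems) }

func (e *dynEntry) MarshalBytes(dst []byte) {
	e.hdr.MarshalUnsafe(dst)               // fast path for the fixed part
	copy(dst[e.hdr.SizeBytes():], e.elems) // hand-marshalled dynamic part
}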
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index b953e62dc..b088b207c 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -86,7 +86,7 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) {
- ke.IPTGetEntries.MarshalBytes(dst)
+ ke.IPTGetEntries.MarshalUnsafe(dst)
marshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
@@ -96,7 +96,7 @@ func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) {
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) {
- ke.IPTGetEntries.UnmarshalBytes(src)
+ ke.IPTGetEntries.UnmarshalUnsafe(src)
unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
@@ -149,8 +149,8 @@ type IP6TEntry struct {
const SizeOfIP6TEntry = 168
// KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field.
-// KernelIP6TEntry itself is not Marshallable but it implements some methods of
-// marshal.Marshallable that help in other implementations of Marshallable.
+//
+// +marshal dynamic
type KernelIP6TEntry struct {
Entry IP6TEntry
@@ -168,13 +168,13 @@ func (ke *KernelIP6TEntry) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) {
- ke.Entry.MarshalBytes(dst)
+ ke.Entry.MarshalUnsafe(dst)
ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
}
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) {
- ke.Entry.UnmarshalBytes(src)
+ ke.Entry.UnmarshalUnsafe(src)
ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
@@ -264,6 +264,8 @@ const (
// NFNATRange corresponds to struct nf_nat_range in
// include/uapi/linux/netfilter/nf_nat.h.
+//
+// +marshal
type NFNATRange struct {
Flags uint32
MinAddr Inet6Addr
diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go
index bf73271c6..600820a0b 100644
--- a/pkg/abi/linux/netfilter_test.go
+++ b/pkg/abi/linux/netfilter_test.go
@@ -15,9 +15,8 @@
package linux
import (
+ "encoding/binary"
"testing"
-
- "gvisor.dev/gvisor/pkg/binary"
)
func TestSizes(t *testing.T) {
@@ -42,7 +41,7 @@ func TestSizes(t *testing.T) {
}
for _, tc := range testCases {
- if calculated := binary.Size(tc.typ); calculated != tc.defined {
+ if calculated := uintptr(binary.Size(tc.typ)); calculated != tc.defined {
t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated)
}
}
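
The test now leans on the standard library's binary.Size, which reports the packed wire size (Go struct padding is ignored) and returns -1 for types without a fixed encoding, so the uintptr cast above only makes sense for the fixed-size structs in the table. A quick illustration:

package main

import (
	"encoding/binary"
	"fmt"
)

type packed struct {
	A uint16
	B uint32
}

func main() {
	fmt.Println(binary.Size(packed{})) // 6: field sizes summed, no padding
	fmt.Println(binary.Size([]int{1})) // -1: int has no fixed wire size
}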
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index b41f94a69..232fee67e 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -53,6 +53,8 @@ type SockAddrNetlink struct {
const SockAddrNetlinkSize = 12
// NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkMessageHeader struct {
Length uint32
Type uint16
@@ -99,6 +101,8 @@ const NLMSG_ALIGNTO = 4
// NetlinkAttrHeader is the header of a netlink attribute, followed by payload.
//
// This is struct nlattr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkAttrHeader struct {
Length uint16
Type uint16
@@ -126,6 +130,8 @@ const (
)
// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkErrorMessage struct {
Error int32
Header NetlinkMessageHeader
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index ceda0a8d3..581a11b24 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -85,6 +85,8 @@ const (
)
// InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h.
+//
+// +marshal
type InterfaceInfoMessage struct {
Family uint8
_ uint8
@@ -164,6 +166,8 @@ const (
)
// InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h.
+//
+// +marshal
type InterfaceAddrMessage struct {
Family uint8
PrefixLen uint8
@@ -193,6 +197,8 @@ const (
)
// RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h.
+//
+// +marshal
type RouteMessage struct {
Family uint8
DstLen uint8
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index 50e22fe7e..e722971f1 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Rsp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Rsp = sp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index da36811d2..3d0906565 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Sp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Sp = sp
+}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 185eee0bb..95871b8a5 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -15,7 +15,6 @@
package linux
import (
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal"
)
@@ -251,18 +250,24 @@ type SockAddrInet struct {
}
// Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h.
+//
+// +marshal
type Inet6MulticastRequest struct {
MulticastAddr Inet6Addr
InterfaceIndex int32
}
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
+//
+// +marshal
type InetMulticastRequest struct {
MulticastAddr InetAddr
InterfaceAddr InetAddr
}
// InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h.
+//
+// +marshal
type InetMulticastRequestWithNIC struct {
InetMulticastRequest
InterfaceIndex int32
@@ -491,7 +496,7 @@ type TCPInfo struct {
}
// SizeOfTCPInfo is the binary size of a TCPInfo struct.
-var SizeOfTCPInfo = int(binary.Size(TCPInfo{}))
+var SizeOfTCPInfo = (*TCPInfo)(nil).SizeBytes()
// Control message types, from linux/socket.h.
const (
@@ -502,6 +507,8 @@ const (
// A ControlMessageHeader is the header for a socket control message.
//
// ControlMessageHeader represents struct cmsghdr from linux/socket.h.
+//
+// +marshal
type ControlMessageHeader struct {
Length uint64
Level int32
@@ -510,7 +517,7 @@ type ControlMessageHeader struct {
// SizeOfControlMessageHeader is the binary size of a ControlMessageHeader
// struct.
-var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
+var SizeOfControlMessageHeader = (*ControlMessageHeader)(nil).SizeBytes()
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
@@ -527,6 +534,7 @@ type ControlMessageCredentials struct {
//
// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
//
+// +marshal
// +stateify savable
type ControlMessageIPPacketInfo struct {
NIC int32
@@ -536,7 +544,7 @@ type ControlMessageIPPacketInfo struct {
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
-var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
+var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes()
// A ControlMessageRights is an SCM_RIGHTS socket control message.
type ControlMessageRights []int32
diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go
index a26433ad6..d16448c3d 100644
--- a/pkg/bits/bits.go
+++ b/pkg/bits/bits.go
@@ -14,3 +14,13 @@
// Package bits includes all bit related types and operations.
package bits
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
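
A worked example of the new helpers; the mask trick clears the low bits of the rounded value, which is why align must be a power of two:

package main

import "fmt"

func AlignUp(length int, align uint) int {
	return (length + int(align) - 1) & ^(int(align) - 1)
}

func AlignDown(length int, align uint) int {
	return length & ^(int(align) - 1)
}

func main() {
	fmt.Println(AlignUp(33, 16))   // 48: (33+15) &^ 15
	fmt.Println(AlignDown(33, 16)) // 32: 33 &^ 15
	fmt.Println(AlignUp(48, 16))   // 48: aligned values are unchanged
}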
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index 2a6977f85..c17390522 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -26,6 +26,7 @@ go_test(
library = ":bpf",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/hostarch",
+ "//pkg/marshal",
],
)
diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go
index c85d786b9..f64a2dc50 100644
--- a/pkg/bpf/interpreter_test.go
+++ b/pkg/bpf/interpreter_test.go
@@ -15,10 +15,12 @@
package bpf
import (
+ "encoding/binary"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
)
func TestCompilationErrors(t *testing.T) {
@@ -750,29 +752,29 @@ func TestSimpleFilter(t *testing.T) {
// desc is the test's description.
desc string
- // seccompData is the input data.
- seccompData
+ // data is the input data.
+ data linux.SeccompData
// expectedRet is the expected return value of the BPF program.
expectedRet uint32
}{
{
desc: "Invalid arch is rejected",
- seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */},
+ data: linux.SeccompData{Nr: 1 /* x86 exit */, Arch: 0x40000003 /* AUDIT_ARCH_I386 */},
expectedRet: 0,
},
{
desc: "Disallowed syscall is rejected",
- seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e},
+ data: linux.SeccompData{Nr: 105 /* __NR_setuid */, Arch: 0xc000003e},
expectedRet: 0,
},
{
desc: "Allowed syscall is indeed allowed",
- seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e},
+ data: linux.SeccompData{Nr: 231 /* __NR_exit_group */, Arch: 0xc000003e},
expectedRet: 0x7fff0000,
},
} {
- ret, err := Exec(p, test.seccompData.asInput())
+ ret, err := Exec(p, dataAsInput(&test.data))
if err != nil {
t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err)
continue
@@ -792,6 +794,6 @@ type seccompData struct {
}
-// asInput converts a seccompData to a bpf.Input.
+// dataAsInput converts a linux.SeccompData to a bpf.Input.
-func (d *seccompData) asInput() Input {
- return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+func dataAsInput(data *linux.SeccompData) Input {
+ return InputBytes{marshal.Marshal(data), hostarch.ByteOrder}
}
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index 1f75319a7..70018cf18 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -6,10 +6,7 @@ go_library(
name = "compressio",
srcs = ["compressio.go"],
visibility = ["//:sandbox"],
- deps = [
- "//pkg/binary",
- "//pkg/sync",
- ],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index b094c5662..615d7f134 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -48,12 +48,12 @@ import (
"compress/flate"
"crypto/hmac"
"crypto/sha256"
+ "encoding/binary"
"errors"
"hash"
"io"
"runtime"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -130,6 +130,10 @@ type worker struct {
hashPool *hashPool
input chan *chunk
output chan result
+
+ // scratch is a temporary buffer used for marshalling. This is declared
+ // upfront here to avoid reallocation.
+ scratch [4]byte
}
// work is the main work routine; see worker.
@@ -167,7 +171,8 @@ func (w *worker) work(compress bool, level int) {
// Write the hash, if enabled.
if h != nil {
- binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
+ h.Write(w.scratch[:4])
c.h = h
h = nil
}
@@ -175,7 +180,8 @@ func (w *worker) work(compress bool, level int) {
// Check the hash of the compressed contents.
if h != nil {
h.Write(c.compressed.Bytes())
- binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
+ h.Write(w.scratch[:4])
io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
sum := h.Sum(nil)
@@ -352,6 +358,10 @@ type Reader struct {
// in is the source.
in io.Reader
+
+ // scratch is a temporary buffer used for marshalling. This is declared
+ // upfront here to avoid reallocation.
+ scratch [4]byte
}
var _ io.Reader = (*Reader)(nil)
@@ -368,14 +378,15 @@ func NewReader(in io.Reader, key []byte) (*Reader, error) {
// Use double buffering for read.
r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
- var err error
- if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
+ if _, err := io.ReadFull(in, r.scratch[:4]); err != nil {
return nil, err
}
+ r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4])
if r.hashPool != nil {
h := r.hashPool.getHash()
- binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize)
+ h.Write(r.scratch[:4])
r.lastSum = h.Sum(nil)
r.hashPool.putHash(h)
sum := make([]byte, len(r.lastSum))
@@ -467,8 +478,7 @@ func (r *Reader) Read(p []byte) (int, error) {
// reader. The length is used to limit the reader.
//
// See writer.flush.
- l, err := binary.ReadUint32(r.in, binary.BigEndian)
- if err != nil {
+ if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil {
// This is generally okay as long as there
// are still buffers outstanding. We actually
// just wait for completion of those buffers here
@@ -488,6 +498,7 @@ func (r *Reader) Read(p []byte) (int, error) {
return done, err
}
}
+ l := binary.BigEndian.Uint32(r.scratch[:4])
// Read this chunk and schedule decompression.
compressed := bufPool.Get().(*bytes.Buffer)
@@ -573,6 +584,10 @@ type Writer struct {
// closed indicates whether the file has been closed.
closed bool
+
+ // scratch is a temporary buffer used for marshalling. This is declared
+ // upfront here to avoid reallocation.
+ scratch [4]byte
}
var _ io.Writer = (*Writer)(nil)
@@ -594,13 +609,15 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer,
}
w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
- if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
+ binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
+ if _, err := w.out.Write(w.scratch[:4]); err != nil {
return nil, err
}
if w.hashPool != nil {
h := w.hashPool.getHash()
- binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
+ h.Write(w.scratch[:4])
w.lastSum = h.Sum(nil)
w.hashPool.putHash(h)
if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
@@ -616,7 +633,9 @@ func (w *Writer) flush(c *chunk) error {
// Prefix each chunk with a length; this allows the reader to safely
// limit reads while buffering.
l := uint32(c.compressed.Len())
- if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil {
+
+ binary.BigEndian.PutUint32(w.scratch[:], l)
+ if _, err := w.out.Write(w.scratch[:4]); err != nil {
return err
}
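
The compressio changes above drop the old pkg/binary read/write helpers in favor of a preallocated scratch buffer plus the standard encoding/binary package. A sketch of the same length-prefix pattern (prefixWriter and writePrefixed are hypothetical helpers, not part of compressio):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

type prefixWriter struct {
	out     io.Writer
	scratch [4]byte // reused for every prefix; avoids a per-write allocation
}

func (w *prefixWriter) writePrefixed(chunk []byte) error {
	binary.BigEndian.PutUint32(w.scratch[:], uint32(len(chunk)))
	if _, err := w.out.Write(w.scratch[:4]); err != nil {
		return err
	}
	_, err := w.out.Write(chunk)
	return err
}

func main() {
	var buf bytes.Buffer
	w := &prefixWriter{out: &buf}
	if err := w.writePrefixed([]byte("hello")); err != nil {
		panic(err)
	}
	fmt.Printf("% x\n", buf.Bytes()) // 00 00 00 05 68 65 6c 6c 6f
}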
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index a6778a005..b33a20802 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -26,6 +26,7 @@ import (
"fmt"
"io"
"sort"
+ "sync/atomic"
"testing"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -34,12 +35,16 @@ import (
"github.com/bazelbuild/rules_go/go/tools/coverdata"
)
-// coverageMu must be held while accessing coverdata.Cover. This prevents
-// concurrent reads/writes from multiple threads collecting coverage data.
-var coverageMu sync.RWMutex
+var (
+ // coverageMu must be held while accessing coverdata.Cover. This prevents
+ // concurrent reads/writes from multiple threads collecting coverage data.
+ coverageMu sync.RWMutex
-// once ensures that globalData is only initialized once.
-var once sync.Once
+ // reportOutput is the place to write out a coverage report. It should be
+ // closed after the report is written. It is protected by reportOutputMu.
+ reportOutput io.WriteCloser
+ reportOutputMu sync.Mutex
+)
// blockBitLength is the number of bits used to represent coverage block index
// in a synthetic PC (the rest are used to represent the file index). Even
@@ -51,12 +56,26 @@ var once sync.Once
// file and every block.
const blockBitLength = 16
-// KcovAvailable returns whether the kcov coverage interface is available. It is
-// available as long as coverage is enabled for some files.
-func KcovAvailable() bool {
+// Available returns whether any coverage data is available.
+func Available() bool {
return len(coverdata.Cover.Blocks) > 0
}
+// EnableReport sets up coverage reporting.
+func EnableReport(w io.WriteCloser) {
+ reportOutputMu.Lock()
+ defer reportOutputMu.Unlock()
+ reportOutput = w
+}
+
+// KcovSupported returns whether the kcov interface should be made available.
+//
+// If coverage reporting is on, do not turn on kcov, which will consume
+// coverage data.
+func KcovSupported() bool {
+ return (reportOutput == nil) && Available()
+}
+
var globalData struct {
// files is the set of covered files sorted by filename. It is calculated at
// startup.
@@ -65,6 +84,9 @@ var globalData struct {
// syntheticPCs are a set of PCs calculated at startup, where the PC
// at syntheticPCs[i][j] corresponds to file i, block j.
syntheticPCs [][]uint64
+
+ // once ensures that globalData is only initialized once.
+ once sync.Once
}
// ClearCoverageData clears existing coverage data.
@@ -166,7 +188,7 @@ func ConsumeCoverageData(w io.Writer) int {
// InitCoverageData initializes globalData. It should be called before any kcov
// data is written.
func InitCoverageData() {
- once.Do(func() {
+ globalData.once.Do(func() {
// First, order all files. Then calculate synthetic PCs for every block
// (using the well-defined ordering for files as well).
for file := range coverdata.Cover.Blocks {
@@ -185,6 +207,38 @@ func InitCoverageData() {
})
}
+// reportOnce ensures that a coverage report is written at most once. For a
+// complete coverage report, Report should be called during the sandbox teardown
+// process. Report is called from multiple places (which may overlap) so that a
+// coverage report is written in different sandbox exit scenarios.
+var reportOnce sync.Once
+
+// Report writes out a coverage report with all blocks that have been covered.
+//
+// TODO(b/144576401): Decide whether this should actually be in LCOV format
+func Report() error {
+ if reportOutput == nil {
+ return nil
+ }
+
+ var err error
+ reportOnce.Do(func() {
+ for file, counters := range coverdata.Cover.Counters {
+ blocks := coverdata.Cover.Blocks[file]
+ for i := 0; i < len(counters); i++ {
+ if atomic.LoadUint32(&counters[i]) > 0 {
+ err = writeBlock(reportOutput, file, blocks[i])
+ if err != nil {
+ return
+ }
+ }
+ }
+ }
+ reportOutput.Close()
+ })
+ return err
+}
+
// Symbolize prints information about the block corresponding to pc.
func Symbolize(out io.Writer, pc uint64) error {
fileNum, blockNum := syntheticPCToIndexes(pc)
@@ -196,18 +250,32 @@ func Symbolize(out io.Writer, pc uint64) error {
if err != nil {
return err
}
- writeBlock(out, pc, file, block)
- return nil
+ return writeBlockWithPC(out, pc, file, block)
}
// WriteAllBlocks prints all information about all blocks along with their
// corresponding synthetic PCs.
-func WriteAllBlocks(out io.Writer) {
+func WriteAllBlocks(out io.Writer) error {
for fileNum, file := range globalData.files {
for blockNum, block := range coverdata.Cover.Blocks[file] {
- writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block)
+ if err := writeBlockWithPC(out, calculateSyntheticPC(fileNum, blockNum), file, block); err != nil {
+ return err
+ }
}
}
+ return nil
+}
+
+func writeBlockWithPC(out io.Writer, pc uint64, file string, block testing.CoverBlock) error {
+ if _, err := io.WriteString(out, fmt.Sprintf("%#x\n", pc)); err != nil {
+ return err
+ }
+ return writeBlock(out, file, block)
+}
+
+func writeBlock(out io.Writer, file string, block testing.CoverBlock) error {
+ _, err := io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
+ return err
}
func calculateSyntheticPC(fileNum int, blockNum int) uint64 {
@@ -239,8 +307,3 @@ func blockFromIndex(file string, i int) (testing.CoverBlock, error) {
}
return blocks[i], nil
}
-
-func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) {
- io.WriteString(out, fmt.Sprintf("%#x\n", pc))
- io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
-}
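
Report above combines three ingredients: counters read with atomic loads (writers may still be running), a sync.Once so overlapping exit paths emit at most one report, and an error captured from inside the closure. A reduced sketch of that shape, with illustrative names:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

var (
	counters   [4]uint32
	reportOnce sync.Once
)

func report() error {
	var err error
	reportOnce.Do(func() {
		for i := range counters {
			// Atomic load: other goroutines may still bump counters.
			if atomic.LoadUint32(&counters[i]) > 0 {
				fmt.Printf("block %d covered\n", i)
			}
		}
		// A write failure would be assigned to err here, as in Report.
	})
	return err
}

func main() {
	atomic.AddUint32(&counters[2], 1)
	report()
	report() // no-op: sync.Once already fired
}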
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
index 35683fe98..b4e05f922 100644
--- a/pkg/gohacks/BUILD
+++ b/pkg/gohacks/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,3 +10,11 @@ go_library(
stateify = False,
visibility = ["//:sandbox"],
)
+
+go_test(
+ name = "gohacks_test",
+ size = "small",
+ srcs = ["gohacks_test.go"],
+ library = ":gohacks",
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/gohacks/gohacks_test.go b/pkg/gohacks/gohacks_test.go
new file mode 100644
index 000000000..e18c8abc7
--- /dev/null
+++ b/pkg/gohacks/gohacks_test.go
@@ -0,0 +1,97 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gohacks
+
+import (
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "runtime/debug"
+ "testing"
+
+ "golang.org/x/sys/unix"
+)
+
+func randBuf(size int) []byte {
+ b := make([]byte, size)
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular
+// dependency.
+const pageSize = 4096
+
+func testCopy(dst, src []byte) (panicked bool) {
+ defer func() {
+ if r := recover(); r != nil {
+ panicked = true
+ }
+ }()
+ debug.SetPanicOnFault(true)
+ copy(dst, src)
+ return panicked
+}
+
+func TestSegVOnMemmove(t *testing.T) {
+ // Test that SIGSEGVs received by runtime.memmove when *not* doing
+ // CopyIn or CopyOut work gets propagated to the runtime.
+ const bufLen = pageSize
+ a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+
+ }
+ defer unix.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
+
+func TestSigbusOnMemmove(t *testing.T) {
+ // Test that SIGBUS received by runtime.memmove when *not* doing
+ // CopyIn or CopyOut work gets propagated to the runtime.
+ const bufLen = pageSize
+ f, err := ioutil.TempFile("", "sigbus_test")
+ if err != nil {
+ t.Fatalf("TempFile failed: %v", err)
+ }
+ os.Remove(f.Name())
+ defer f.Close()
+
+ a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+
+ }
+ defer unix.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
index 10bbb1f58..374aac2b4 100644
--- a/pkg/gohacks/gohacks_unsafe.go
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -75,3 +75,17 @@ func StringFromImmutableBytes(bs []byte) string {
// strings.Builder.String().
return *(*string)(unsafe.Pointer(&bs))
}
+
+// Note that go:linkname silently doesn't work if the local name is exported,
+// necessitating an indirection for exported functions.
+
+// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
+//
+//go:nosplit
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD
index 7cd89e639..7a5002176 100644
--- a/pkg/marshal/BUILD
+++ b/pkg/marshal/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"marshal.go",
"marshal_impl_util.go",
+ "util.go",
],
visibility = [
"//:sandbox",
diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go
index eb036feae..7da450ce8 100644
--- a/pkg/marshal/marshal.go
+++ b/pkg/marshal/marshal.go
@@ -166,6 +166,9 @@ type Marshallable interface {
// %s is the first argument to the slice clause. This directive is not supported
// for newtypes on arrays.
//
+// Note: Partial copies are not supported for Slice API UnmarshalUnsafe and
+// MarshalUnsafe.
+//
// The slice clause also takes an optional second argument, which must be the
// value "inner":
//
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index 32c8ed138..6f38992b7 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -125,6 +125,81 @@ func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) {
var _ marshal.Marshallable = (*ByteSlice)(nil)
+// The following set of functions are convenient shorthands for wrapping a
+// built-in type in a marshallable primitive type. For example:
+//
+// func useMarshallable(m marshal.Marshallable) { ... }
+//
+// // Compare:
+//
+// buf = []byte{...}
+// // useMarshallable(&primitive.ByteSlice(buf)) // Not allowed, can't address temp value.
+// bufP := primitive.ByteSlice(buf)
+// useMarshallable(&bufP)
+//
+// // Vs:
+//
+// useMarshallable(AsByteSlice(buf))
+//
+// Note that the argument to these functions escapes, so avoid using them on
+// very hot code paths. But generally if a function accepts an interface as an
+// argument, the argument escapes anyway.
+
+// AllocateInt8 returns x as a marshallable.
+func AllocateInt8(x int8) marshal.Marshallable {
+ p := Int8(x)
+ return &p
+}
+
+// AllocateUint8 returns x as a marshallable.
+func AllocateUint8(x uint8) marshal.Marshallable {
+ p := Uint8(x)
+ return &p
+}
+
+// AllocateInt16 returns x as a marshallable.
+func AllocateInt16(x int16) marshal.Marshallable {
+ p := Int16(x)
+ return &p
+}
+
+// AllocateUint16 returns x as a marshallable.
+func AllocateUint16(x uint16) marshal.Marshallable {
+ p := Uint16(x)
+ return &p
+}
+
+// AllocateInt32 returns x as a marshallable.
+func AllocateInt32(x int32) marshal.Marshallable {
+ p := Int32(x)
+ return &p
+}
+
+// AllocateUint32 returns x as a marshallable.
+func AllocateUint32(x uint32) marshal.Marshallable {
+ p := Uint32(x)
+ return &p
+}
+
+// AllocateInt64 returns x as a marshallable.
+func AllocateInt64(x int64) marshal.Marshallable {
+ p := Int64(x)
+ return &p
+}
+
+// AllocateUint64 returns x as a marshallable.
+func AllocateUint64(x uint64) marshal.Marshallable {
+ p := Uint64(x)
+ return &p
+}
+
+// AsByteSlice returns b as a marshallable. Note that this allocates a new slice
+// header, but does not copy the slice contents.
+func AsByteSlice(b []byte) marshal.Marshallable {
+ bs := ByteSlice(b)
+ return &bs
+}
+
// Below, we define some convenience functions for marshalling primitive types
// using the newtypes above, without requiring superfluous casts.
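
A usage sketch for the Allocate*/AsByteSlice helpers above; logSize is a hypothetical caller, and the point is only that no addressable temporary is needed at the call site:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/marshal"
	"gvisor.dev/gvisor/pkg/marshal/primitive"
)

func logSize(m marshal.Marshallable) {
	fmt.Println(m.SizeBytes())
}

func main() {
	logSize(primitive.AllocateUint32(42))        // 4
	logSize(primitive.AsByteSlice([]byte("ab"))) // 2: the wrapped slice's length
}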
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/marshal/util.go
index c9dc7e773..c1e5475bd 100644
--- a/pkg/tcpip/transport/tcp/rack_state.go
+++ b/pkg/marshal/util.go
@@ -12,18 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+package marshal
-import (
- "time"
-)
-
-// saveXmitTime is invoked by stateify.
-func (rc *rackControl) saveXmitTime() unixTime {
- return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
-}
-
-// loadXmitTime is invoked by stateify.
-func (rc *rackControl) loadXmitTime(unix unixTime) {
- rc.xmitTime = time.Unix(unix.second, unix.nano)
+// Marshal returns the serialized contents of m in a newly allocated
+// byte slice.
+func Marshal(m Marshallable) []byte {
+ buf := make([]byte, m.SizeBytes())
+ m.MarshalUnsafe(buf)
+ return buf
}
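
A short usage sketch for the relocated helper, using NetlinkMessageHeader (made marshallable earlier in this change) as the example type; the field values are arbitrary:

// Serialize any Marshallable in one call: the buffer is sized via
// SizeBytes and filled via MarshalUnsafe.
hdr := linux.NetlinkMessageHeader{Length: 16, Type: 2}
buf := marshal.Marshal(&hdr) // len(buf) == hdr.SizeBytes()
_ = buf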
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 961bd4dcf..ac7868ad9 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -36,7 +36,6 @@ const (
)
// DigestSize returns the size (in bytes) of a digest.
-// TODO(b/156980949): Allow config SHA384.
func DigestSize(hashAlgorithm int) int {
switch hashAlgorithm {
case linux.FS_VERITY_HASH_ALG_SHA256:
@@ -69,7 +68,6 @@ func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool)
blockSize: hostarch.PageSize,
}
- // TODO(b/156980949): Allow config SHA384.
switch hashAlgorithms {
case linux.FS_VERITY_HASH_ALG_SHA256:
layout.digestSize = sha256DigestSize
@@ -238,6 +236,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
+ Children: params.Children,
SymlinkTarget: params.SymlinkTarget,
}
@@ -428,8 +427,6 @@ func Verify(params *VerifyParams) (int64, error) {
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
- // TODO(b/162908070): Investigate possible issues with zero
- // padding the data.
if bytesRead < len(buf) {
for j := bytesRead; j < len(buf); j++ {
buf[j] = 0
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index c9f9357de..e822fe77d 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -35,10 +35,15 @@ var (
// ErrInitializationDone indicates that the caller tried to create a
// new metric after initialization.
ErrInitializationDone = errors.New("metric cannot be created after initialization is complete")
+
+ // WeirdnessMetric is a metric with fields created to track the number
+ // of weird occurrences such as time fallback, partial_result and
+ // vsyscall count.
+ WeirdnessMetric *Uint64Metric
)
// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
-// monitored.
+// monitored. We currently support metrics with at most one field.
//
// Metrics are not saved across save/restore and thus reset to zero on restore.
//
@@ -46,6 +51,16 @@ var (
type Uint64Metric struct {
// value is the actual value of the metric. It must be accessed atomically.
value uint64
+
+ // numFields is the number of metric fields. It is immutable once
+ // initialized.
+ numFields int
+
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+
+ // fields is the map of fields in the metric.
+ fields map[string]uint64
}
var (
@@ -97,8 +112,19 @@ type customUint64Metric struct {
// metadata describes the metric. It is immutable.
metadata *pb.MetricMetadata
- // value returns the current value of the metric.
- value func() uint64
+ // value returns the current value of the metric for the given set of
+ // fields. It takes a variadic number of field values as argument.
+ value func(fieldValues ...string) uint64
+}
+
+// Field contains the field name and allowed values for the metric which is
+// used in registration of the metric.
+type Field struct {
+ // name is the metric field name.
+ name string
+
+ // allowedValues is the list of allowed values for the field.
+ allowedValues []string
}
// RegisterCustomUint64Metric registers a metric with the given name.
@@ -109,7 +135,8 @@ type customUint64Metric struct {
// Preconditions:
// * name must be globally unique.
// * Initialize/Disable have not been called.
-func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error {
+// * value is expected to accept exactly len(fields) arguments.
+func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func(...string) uint64, fields ...Field) error {
if initialized {
return ErrInitializationDone
}
@@ -129,13 +156,25 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
},
value: value,
}
+
+ // Metrics may have at most one field (and may have none).
+ if len(fields) > 1 {
+ panic("Sentry metrics support at most one field")
+ }
+
+ for _, field := range fields {
+ allMetrics.m[name].metadata.Fields = append(allMetrics.m[name].metadata.Fields, &pb.MetricMetadata_Field{
+ FieldName: field.name,
+ AllowedValues: field.allowedValues,
+ })
+ }
return nil
}
-// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics
-// if it returns an error.
-func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) {
- if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil {
+// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric for metrics
+// without fields and panics if it returns an error.
+func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) {
+ if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil {
panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
}
}
@@ -144,15 +183,24 @@ func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, descript
// name.
//
// Metrics must be statically defined (i.e., at init).
-func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) {
- var m Uint64Metric
- return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value)
+func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string, fields ...Field) (*Uint64Metric, error) {
+ m := Uint64Metric{
+ numFields: len(fields),
+ }
+
+ if m.numFields == 1 {
+ m.fields = make(map[string]uint64)
+ for _, fieldValue := range fields[0].allowedValues {
+ m.fields[fieldValue] = 0
+ }
+ }
+ return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value, fields...)
}
// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an
// error.
-func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric {
- m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description)
+func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...)
if err != nil {
panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
}
@@ -169,19 +217,56 @@ func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description st
return m
}
-// Value returns the current value of the metric.
-func (m *Uint64Metric) Value() uint64 {
- return atomic.LoadUint64(&m.value)
+// Value returns the current value of the metric for the given set of fields.
+func (m *Uint64Metric) Value(fieldValues ...string) uint64 {
+ if m.numFields != len(fieldValues) {
+ panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields))
+ }
+
+ switch m.numFields {
+ case 0:
+ return atomic.LoadUint64(&m.value)
+ case 1:
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ fieldValue := fieldValues[0]
+ if _, ok := m.fields[fieldValue]; !ok {
+ panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue))
+ }
+ return m.fields[fieldValue]
+ default:
+ panic("Sentry metrics do not support more than one field")
+ }
}
-// Increment increments the metric by 1.
-func (m *Uint64Metric) Increment() {
- atomic.AddUint64(&m.value, 1)
+// Increment increments the metric field by 1.
+func (m *Uint64Metric) Increment(fieldValues ...string) {
+ m.IncrementBy(1, fieldValues...)
}
// IncrementBy increments the metric by v.
-func (m *Uint64Metric) IncrementBy(v uint64) {
- atomic.AddUint64(&m.value, v)
+func (m *Uint64Metric) IncrementBy(v uint64, fieldValues ...string) {
+ if m.numFields != len(fieldValues) {
+ panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields))
+ }
+
+ switch m.numFields {
+ case 0:
+ atomic.AddUint64(&m.value, v)
+ return
+ case 1:
+ fieldValue := fieldValues[0]
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if _, ok := m.fields[fieldValue]; !ok {
+ panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue))
+ }
+ m.fields[fieldValue] += v
+ default:
+ panic("Sentry metrics do not support more than one field")
+ }
}
// metricSet holds named metrics.
@@ -199,14 +284,30 @@ func makeMetricSet() metricSet {
// Values returns a snapshot of all values in m.
func (m *metricSet) Values() metricValues {
vals := make(metricValues)
+
for k, v := range m.m {
- vals[k] = v.value()
+ fields := v.metadata.GetFields()
+ switch len(fields) {
+ case 0:
+ vals[k] = v.value()
+ case 1:
+ values := fields[0].GetAllowedValues()
+ fieldsMap := make(map[string]uint64)
+ for _, fieldValue := range values {
+ fieldsMap[fieldValue] = v.value(fieldValue)
+ }
+ vals[k] = fieldsMap
+ default:
+ panic(fmt.Sprintf("Unsupported number of metric fields: %d", len(fields)))
+ }
}
return vals
}
-// metricValues contains a copy of the values of all metrics.
-type metricValues map[string]uint64
+// metricValues contains a copy of the values of all metrics. It is a map
+// keyed by metric name; each value is either a uint64 (for metrics without
+// fields) or a map[string]uint64 (for metrics with one field).
+type metricValues map[string]interface{}
var (
// emitMu protects metricsAtLastEmit and ensures that all emitted
@@ -233,14 +334,37 @@ func EmitMetricUpdate() {
snapshot := allMetrics.Values()
m := pb.MetricUpdate{}
+ // On the first call metricsAtLastEmit will be empty, so all
+ // metrics are included.
for k, v := range snapshot {
- // On the first call metricsAtLastEmit will be empty. Include
- // all metrics then.
- if prev, ok := metricsAtLastEmit[k]; !ok || prev != v {
+ prev, ok := metricsAtLastEmit[k]
+ switch t := v.(type) {
+ case uint64:
+ // Metric exists and value did not change.
+ if ok && prev.(uint64) == t {
+ continue
+ }
+
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
- Value: &pb.MetricValue_Uint64Value{v},
+ Value: &pb.MetricValue_Uint64Value{t},
})
+ case map[string]uint64:
+ for fieldValue, metricValue := range t {
+ // Emit data on the first call only if the field
+ // value has been incremented. For all other
+ // calls, emit data if the field value has
+ // changed since the previous emit.
+ if (!ok && metricValue == 0) || (ok && prev.(map[string]uint64)[fieldValue] == metricValue) {
+ continue
+ }
+
+ m.Metrics = append(m.Metrics, &pb.MetricValue{
+ Name: k,
+ FieldValues: []string{fieldValue},
+ Value: &pb.MetricValue_Uint64Value{metricValue},
+ })
+ }
}
}
@@ -261,3 +385,16 @@ func EmitMetricUpdate() {
eventchannel.Emit(&m)
}
+
+// CreateSentryMetrics creates the sentry metrics during kernel initialization.
+func CreateSentryMetrics() {
+ if WeirdnessMetric != nil {
+ return
+ }
+
+ WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences such as time fallback, partial result and vsyscalls invoked in the sandbox",
+ Field{
+ name: "weirdness_type",
+ allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count"},
+ })
+}
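
A minimal sketch of the fielded-metric API defined above, written from within package metric (as the tests are) since Field's members are unexported; names and values are illustrative:

    counter := MustCreateNewUint64Metric("/example/events", true /* sync */, "Example events",
            Field{name: "kind", allowedValues: []string{"good", "bad"}})
    counter.Increment("bad")       // "bad" -> 1
    counter.IncrementBy(3, "good") // "good" -> 3
    _ = counter.Value("good")      // == 3

Passing a field value outside allowedValues, or the wrong number of field values, panics per the checks in Value and IncrementBy.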
diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto
index 3cc89047d..53c8b4b50 100644
--- a/pkg/metric/metric.proto
+++ b/pkg/metric/metric.proto
@@ -48,6 +48,15 @@ message MetricMetadata {
// units is the units of the metric value.
Units units = 6;
+
+ message Field {
+ string field_name = 1;
+ repeated string allowed_values = 2;
+ }
+
+ // fields contains the metric fields. Currently a metric can have at most
+ // one field.
+ repeated Field fields = 7;
}
// MetricRegistration contains the metadata for all metrics that will be in
@@ -66,6 +75,8 @@ message MetricValue {
oneof value {
uint64 uint64_value = 2;
}
+
+ repeated string field_values = 4;
}
// MetricUpdate contains new values for multiple distinct metrics.
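
For illustration, a single fielded-metric value emitted by EmitMetricUpdate might render as the following proto text (a sketch; it assumes the repeated MetricValue field in MetricUpdate is named metrics):

    metrics {
      name: "/weirdness"
      field_values: "time_fallback"
      uint64_value: 1
    }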
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index aefd0ea5c..c71dfd460 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -59,8 +59,9 @@ func reset() {
}
const (
- fooDescription = "Foo!"
- barDescription = "Bar Baz"
+ fooDescription = "Foo!"
+ barDescription = "Bar Baz"
+ counterDescription = "Counter"
)
func TestInitialize(t *testing.T) {
@@ -95,7 +96,7 @@ func TestInitialize(t *testing.T) {
foundBar := false
for _, m := range mr.Metrics {
if m.Type != pb.MetricMetadata_TYPE_UINT64 {
- t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64)
+ t.Errorf("Metadata %+v Type got %v want pb.MetricMetadata_TYPE_UINT64", m, m.Type)
}
if !m.Cumulative {
t.Errorf("Metadata %+v Cumulative got false want true", m)
@@ -256,3 +257,88 @@ func TestEmitMetricUpdate(t *testing.T) {
t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value)
}
}
+
+func TestEmitMetricUpdateWithFields(t *testing.T) {
+ defer reset()
+
+ field := Field{
+ name: "weirdness_type",
+ allowedValues: []string{"weird1", "weird2"}}
+
+ counter, err := NewUint64Metric("/weirdness", false, pb.MetricMetadata_UNITS_NONE, counterDescription, field)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ Initialize()
+
+ // Don't care about the registration metrics.
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ // For metrics with fields, we do not emit data unless the value is
+ // incremented.
+ if len(emitter) != 0 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 0", len(emitter))
+ }
+
+ counter.IncrementBy(4, "weird1")
+ counter.Increment("weird2")
+
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ if len(emitter) != 1 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter))
+ }
+
+ update, ok := emitter[0].(*pb.MetricUpdate)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0])
+ }
+
+ if len(update.Metrics) != 2 {
+ t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics))
+ }
+
+ foundWeird1 := false
+ foundWeird2 := false
+ for i := 0; i < len(update.Metrics); i++ {
+ m := update.Metrics[i]
+
+ if m.Name != "/weirdness" {
+ t.Errorf("Metric %+v name got %q want '/weirdness'", m, m.Name)
+ }
+ if len(m.FieldValues) != 1 {
+ t.Errorf("MetricUpdate got %d fields want 1", len(m.FieldValues))
+ }
+
+ switch m.FieldValues[0] {
+ case "weird1":
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ }
+ if uv.Uint64Value != 4 {
+ t.Errorf("%v: Value got %v want 4", m, uv.Uint64Value)
+ }
+ foundWeird1 = true
+ case "weird2":
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ }
+ if uv.Uint64Value != 1 {
+ t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value)
+ }
+ foundWeird2 = true
+ }
+ }
+
+ if !foundWeird1 {
+ t.Errorf("Field value weird1 not found: %+v", emitter)
+ }
+ if !foundWeird2 {
+ t.Errorf("Field value weird2 not found: %+v", emitter)
+ }
+}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 7abc82e1b..28396b0ea 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At
return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil
}
+func (c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, unix.EBADF
+ }
+
+ if !versionSupportsTmultiGetAttr(c.client.version) {
+ return DefaultMultiGetAttr(c, names)
+ }
+
+ rmultigetattr := Rmultigetattr{}
+ if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil {
+ return nil, err
+ }
+ return rmultigetattr.Stats, nil
+}
+
// StatFS implements File.StatFS.
func (c *clientFile) StatFS() (FSStat, error) {
if atomic.LoadUint32(&c.closed) != 0 {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index c59c6a65b..97e0231d6 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -15,6 +15,8 @@
package p9
import (
+ "errors"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/fd"
)
@@ -72,6 +74,15 @@ type File interface {
// On the server, WalkGetAttr has a read concurrency guarantee.
WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error)
+ // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of
+ // path components similar to Walk(). If the first component name is empty,
+ // the current file is stat'd and included in the results. If the walk reaches
+ // a file that doesn't exist or is not a directory, MultiGetAttr returns the
+ // partial result with no error.
+ //
+ // On the server, MultiGetAttr has a read concurrency guarantee.
+ MultiGetAttr(names []string) ([]FullStat, error)
+
// StatFS returns information about the file system associated with
// this file.
//
@@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error {
type DisallowServerCalls struct{}
// Renamed implements File.Renamed.
-func (*clientFile) Renamed(File, string) {
+func (*DisallowServerCalls) Renamed(File, string) {
panic("Renamed should not be called on the client")
}
+
+// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File.
+func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
+ stats := make([]FullStat, 0, len(names))
+ parent := start
+ mask := AttrMaskAll()
+ for i, name := range names {
+ if len(name) == 0 && i == 0 {
+ qid, valid, attr, err := parent.GetAttr(mask)
+ if err != nil {
+ return nil, err
+ }
+ stats = append(stats, FullStat{
+ QID: qid,
+ Valid: valid,
+ Attr: attr,
+ })
+ continue
+ }
+ qids, child, valid, attr, err := parent.WalkGetAttr([]string{name})
+ if parent != start {
+ _ = parent.Close()
+ }
+ if err != nil {
+ if errors.Is(err, unix.ENOENT) {
+ return stats, nil
+ }
+ return nil, err
+ }
+ stats = append(stats, FullStat{
+ QID: qids[0],
+ Valid: valid,
+ Attr: attr,
+ })
+ if attr.Mode.FileType() != ModeDirectory {
+ // Stop walking if the entry is not a directory; this
+ // includes symlinks, which cannot be followed.
+ _ = child.Close()
+ break
+ }
+ parent = child
+ }
+ if parent != start {
+ _ = parent.Close()
+ }
+ return stats, nil
+}
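
A hedged sketch of the call from a client's perspective, where file is any p9.File and error handling is elided:

    // Stat the file itself (empty first name), then entries along a path.
    stats, _ := file.MultiGetAttr([]string{"", "usr", "lib", "libc.so"})
    // stats[0] describes file, stats[1] "usr", and so on. The result may be
    // partial (with no error) if the walk reaches a missing entry or a
    // non-directory.

On servers that do not advertise version 13, clientFile.MultiGetAttr falls back to DefaultMultiGetAttr, which emulates the batch with per-component WalkGetAttr calls as shown above.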
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 58312d0cc..758e11b13 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message {
}
return rchannel
}
+
+// handle implements handler.handle.
+func (t *Tmultigetattr) handle(cs *connState) message {
+ for i, name := range t.Names {
+ if len(name) == 0 && i == 0 {
+ // Empty name is allowed on the first entry to indicate that the current
+ // FID needs to be included in the result.
+ continue
+ }
+ if err := checkSafeName(name); err != nil {
+ return newErr(err)
+ }
+ }
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(unix.EBADF)
+ }
+ defer ref.DecRef()
+
+ var stats []FullStat
+ if err := ref.safelyRead(func() (err error) {
+ stats, err = ref.file.MultiGetAttr(t.Names)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rmultigetattr{Stats: stats}
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index cf13cbb69..2ff4694c0 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) {
// encode implements encoder.encode.
func (r *Rwalk) encode(b *buffer) {
b.Write16(uint16(len(r.QIDs)))
- for _, q := range r.QIDs {
- q.encode(b)
+ for i := range r.QIDs {
+ r.QIDs[i].encode(b)
}
}
@@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) {
r.Valid.encode(b)
r.Attr.encode(b)
b.Write16(uint16(len(r.QIDs)))
- for _, q := range r.QIDs {
- q.encode(b)
+ for i := range r.QIDs {
+ r.QIDs[i].encode(b)
}
}
@@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string {
return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
+// Tmultigetattr is a multi-getattr request.
+type Tmultigetattr struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Tmultigetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Tmultigetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Tmultigetattr) Type() MsgType {
+ return MsgTmultigetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tmultigetattr) String() string {
+ return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names)
+}
+
+// Rmultigetattr is a multi-getattr response.
+type Rmultigetattr struct {
+ // Stats are the set of FullStat returned for each of the names in the
+ // request.
+ Stats []FullStat
+}
+
+// decode implements encoder.decode.
+func (r *Rmultigetattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Stats = r.Stats[:0]
+ for i := 0; i < int(n); i++ {
+ var fs FullStat
+ fs.decode(b)
+ r.Stats = append(r.Stats, fs)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rmultigetattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Stats)))
+ for i := range r.Stats {
+ r.Stats[i].encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rmultigetattr) Type() MsgType {
+ return MsgRmultigetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rmultigetattr) String() string {
+ return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats)
+}
+
const maxCacheSize = 3
// msgFactory is used to reduce allocations by caching messages for reuse.
@@ -2717,6 +2791,8 @@ func init() {
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} })
msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} })
+ msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} })
+ msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} })
msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 648cf4b49..3d452a0bd 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -402,6 +402,8 @@ const (
MsgRallocate MsgType = 139
MsgTsetattrclunk MsgType = 140
MsgRsetattrclunk MsgType = 141
+ MsgTmultigetattr MsgType = 142
+ MsgRmultigetattr MsgType = 143
MsgTchannel MsgType = 250
MsgRchannel MsgType = 251
)
@@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) {
}
b.Write32(mask)
}
+
+// FullStat is used in the result of a MultiGetAttr call.
+type FullStat struct {
+ QID QID
+ Valid AttrMask
+ Attr Attr
+}
+
+// String implements fmt.Stringer.
+func (f *FullStat) String() string {
+ return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr)
+}
+
+// decode implements encoder.decode.
+func (f *FullStat) decode(b *buffer) {
+ f.QID.decode(b)
+ f.Valid.decode(b)
+ f.Attr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (f *FullStat) encode(b *buffer) {
+ f.QID.encode(b)
+ f.Valid.encode(b)
+ f.Attr.encode(b)
+}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 8d7168ef5..950236162 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 12
+ highestSupportedVersion uint32 = 13
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool {
func versionSupportsTsetattrclunk(v uint32) bool {
return v >= 12
}
+
+// versionSupportsTmultiGetAttr returns true if version v supports the
+// Tmultigetattr message.
+func versionSupportsTmultiGetAttr(v uint32) bool {
+ return v >= 13
+}
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 6992e1de8..4aecb8007 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -30,6 +30,9 @@ import (
// RefCounter is the interface to be implemented by objects that are reference
// counted.
+//
+// TODO(gvisor.dev/issue/1624): Get rid of most of this package and replace it
+// with refsvfs2.
type RefCounter interface {
// IncRef increments the reference counter on the object.
IncRef()
@@ -181,6 +184,9 @@ func (w *WeakRef) zap() {
// AtomicRefCount keeps a reference count using atomic operations and calls the
// destructor when the count reaches zero.
//
+// Do not use AtomicRefCount for new ref-counted objects! It is deprecated in
+// favor of the refsvfs2 package.
+//
// N.B. To allow the zero-object to be initialized, the count is offset by
// 1, that is, when refCount is n, there are really n+1 references.
//
@@ -215,8 +221,8 @@ type AtomicRefCount struct {
// LeakMode configures the leak checker.
type LeakMode uint32
-// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref
-// counting is gone.
+// TODO(gvisor.dev/issue/1624): Simplify down to two modes (on/off) once vfs1
+// ref counting is gone.
const (
// UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
UninitializedLeakChecking LeakMode = iota
diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD
index 0377c0876..7c1a8c792 100644
--- a/pkg/refsvfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -1,3 +1,5 @@
+# TODO(gvisor.dev/issue/1624): rename this directory/package to "refs" once VFS1
+# is gone and the current refs package can be deleted.
load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template")
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
index 0472eca3f..fb8984dd6 100644
--- a/pkg/refsvfs2/refs_map.go
+++ b/pkg/refsvfs2/refs_map.go
@@ -112,20 +112,27 @@ func logEvent(obj CheckedObject, msg string) {
log.Infof("[%s %p] %s:\n%s", obj.RefType(), obj, msg, refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
}
+// checkOnce makes sure that leak checking is only done once. DoLeakCheck is
+// called from multiple places (which may overlap) to cover different sandbox
+// exit scenarios.
+var checkOnce sync.Once
+
// DoLeakCheck iterates through the live object map and logs a message for each
// object. It is called once no reference-counted objects should be reachable
// anymore, at which point anything left in the map is considered a leak.
func DoLeakCheck() {
if leakCheckEnabled() {
- liveObjectsMu.Lock()
- defer liveObjectsMu.Unlock()
- leaked := len(liveObjects)
- if leaked > 0 {
- msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked)
- for obj := range liveObjects {
- msg += obj.LeakMessage() + "\n"
+ checkOnce.Do(func() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ leaked := len(liveObjects)
+ if leaked > 0 {
+ msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked)
+ for obj := range liveObjects {
+ msg += obj.LeakMessage() + "\n"
+ }
+ log.Warningf(msg)
}
- log.Warningf(msg)
- }
+ })
}
}
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 92d2330cb..41dfd0bf9 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -250,6 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
+ RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
return
}
@@ -321,3 +322,21 @@ func SetCPUIDFaulting(on bool) bool {
func ReadCR2() uintptr {
return readCR2()
}
+
+// kernelMXCSR is the value of the mxcsr register in the Sentry.
+//
+// The MXCSR control configuration is initialized once and never changed. See
+// src/cmd/compile/abi-internal.md in the Go sources for more details.
+var kernelMXCSR uint32
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+ // Restore the MXCSR control configuration.
+ ldmxcsr(&kernelMXCSR)
+}
+
+func init() {
+ stmxcsr(&kernelMXCSR)
+}
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index 7975e5f92..21db910a2 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
if switchOpts.Flush {
- FlushTlbByASID(uintptr(switchOpts.UserASID))
+ LocalFlushTlbByASID(uintptr(switchOpts.UserASID))
}
regs := switchOpts.Registers
@@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
return
}
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+}
diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go
index 0ec5c3bc5..3e6bb9663 100644
--- a/pkg/ring0/lib_amd64.go
+++ b/pkg/ring0/lib_amd64.go
@@ -61,6 +61,12 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
+// stmxcsr reads the MXCSR control and status register.
+func stmxcsr(addr *uint32)
+
+// ldmxcsr writes to the MXCSR control and status register.
+func ldmxcsr(addr *uint32)
+
// readCR2 reads the current CR2 value.
func readCR2() uintptr
diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s
index 2fe83568a..70a43e79e 100644
--- a/pkg/ring0/lib_amd64.s
+++ b/pkg/ring0/lib_amd64.s
@@ -198,3 +198,15 @@ TEXT ·rdmsr(SB),NOSPLIT,$0-16
MOVL AX, ret+8(FP)
MOVL DX, ret+12(FP)
RET
+
+// stmxcsr reads the MXCSR control and status register.
+TEXT ·stmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ STMXCSR (SI)
+ RET
+
+// ldmxcsr writes to the MXCSR control and status register.
+TEXT ·ldmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ LDMXCSR (SI)
+ RET
diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go
index e44df00a6..5eabd4296 100644
--- a/pkg/ring0/lib_arm64.go
+++ b/pkg/ring0/lib_arm64.go
@@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr)
// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
func FlushTlbByASID(asid uintptr)
+// LocalFlushTlbByASID invalidates tlb by ASID.
+func LocalFlushTlbByASID(asid uintptr)
+
// FlushTlbAll invalidates all tlb.
func FlushTlbAll()
diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s
index e39b32841..69ebaf519 100644
--- a/pkg/ring0/lib_arm64.s
+++ b/pkg/ring0/lib_arm64.s
@@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
DSB $11 // dsb(ish)
RET
+TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ LSL $TLBI_ASID_SHIFT, R1, R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd5088741 // tlbi aside1, x1
+ DSB $11 // dsb(ish)
+ RET
+
TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
DSB $6 // dsb(nshst)
WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD
index f8f160cc6..f855f4d42 100644
--- a/pkg/ring0/pagetables/BUILD
+++ b/pkg/ring0/pagetables/BUILD
@@ -84,8 +84,5 @@ go_test(
":walker_check_arm64",
],
library = ":pagetables",
- deps = [
- "//pkg/hostarch",
- "//pkg/usermem",
- ],
+ deps = ["//pkg/hostarch"],
)
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 3f17fba49..9dac53c80 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -322,3 +322,12 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc
func (p *PageTables) MarkReadOnlyShared() {
p.readOnlyShared = true
}
+
+// PrefaultRootTable touches the root table page to be sure that its physical
+// pages are mapped.
+//
+//go:nosplit
+//go:noinline
+func (p *PageTables) PrefaultRootTable() PTE {
+ return p.root[0]
+}
diff --git a/pkg/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go
index 69320c2fb..2514b9ac5 100644
--- a/pkg/ring0/pagetables/pagetables_arm64_test.go
+++ b/pkg/ring0/pagetables/pagetables_arm64_test.go
@@ -19,24 +19,24 @@ package pagetables
import (
"testing"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func Test2MAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a huge page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
- pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*47)
- pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42)
- pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47)
+ pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: false}, pteSize*42)
+ pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: false}, pmdSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
- {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
- {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}},
- {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}},
+ {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: false}},
+ {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: false}},
})
}
@@ -44,12 +44,12 @@ func Test1GAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a super page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
- pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
- {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}},
+ {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
@@ -57,12 +57,12 @@ func TestSplit1GPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a super page and knock out the middle.
- pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*42)
pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
- {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
@@ -70,11 +70,11 @@ func TestSplit2MPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a huge page and knock out the middle.
- pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*42)
pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
- {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s
index a0cd78f33..d513f16c9 100644
--- a/pkg/safecopy/atomic_amd64.s
+++ b/pkg/safecopy/atomic_amd64.s
@@ -24,12 +24,12 @@ TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24
MOVL DI, sig+20(FP)
RET
-// swapUint32 atomically stores new into *addr and returns (the previous *addr
+// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
-// Preconditions: addr must be aligned to a 4-byte boundary.
+// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
TEXT ·swapUint32(SB), NOSPLIT, $0-24
@@ -38,12 +38,18 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24
// handleSwapUint32Fault will store a different value in this address.
MOVL $0, sig+20(FP)
- MOVQ addr+0(FP), DI
+ MOVQ ptr+0(FP), DI
MOVL new+8(FP), AX
XCHGL AX, 0(DI)
MOVL AX, old+16(FP)
RET
+// func addrOfSwapUint32() uintptr
+TEXT ·addrOfSwapUint32(SB), $0-8
+ MOVQ $·swapUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleSwapUint64Fault returns the value stored in DI. Control is transferred
// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -54,12 +60,12 @@ TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28
MOVL DI, sig+24(FP)
RET
-// swapUint64 atomically stores new into *addr and returns (the previous *addr
+// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
-// Preconditions: addr must be aligned to a 8-byte boundary.
+// Preconditions: ptr must be aligned to an 8-byte boundary.
//
//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
TEXT ·swapUint64(SB), NOSPLIT, $0-28
@@ -68,12 +74,18 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28
// handleSwapUint64Fault will store a different value in this address.
MOVL $0, sig+24(FP)
- MOVQ addr+0(FP), DI
+ MOVQ ptr+0(FP), DI
MOVQ new+8(FP), AX
XCHGQ AX, 0(DI)
MOVQ AX, old+16(FP)
RET
+// func addrOfSwapUint64() uintptr
+TEXT ·addrOfSwapUint64(SB), $0-8
+ MOVQ $·swapUint64(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is
// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the
// signal number stored in DI.
@@ -85,11 +97,11 @@ TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24
RET
// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is
+// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is
// received during the operation, the value of prev is unspecified, and sig is
// the number of the signal that was received.
//
-// Preconditions: addr must be aligned to a 4-byte boundary.
+// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
@@ -99,7 +111,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
// address.
MOVL $0, sig+20(FP)
- MOVQ addr+0(FP), DI
+ MOVQ ptr+0(FP), DI
MOVL old+8(FP), AX
MOVL new+12(FP), DX
LOCK
@@ -107,6 +119,12 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
MOVL AX, prev+16(FP)
RET
+// func addrOfCompareAndSwapUint32() uintptr
+TEXT ·addrOfCompareAndSwapUint32(SB), $0-8
+ MOVQ $·compareAndSwapUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleLoadUint32Fault returns the value stored in DI. Control is transferred
// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -117,11 +135,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
MOVL DI, sig+12(FP)
RET
-// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS
// signal is received, the value returned is unspecified, and sig is the number
// of the signal that was received.
//
-// Preconditions: addr must be aligned to a 4-byte boundary.
+// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
TEXT ·loadUint32(SB), NOSPLIT, $0-16
@@ -130,7 +148,13 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16
// handleLoadUint32Fault will store a different value in this address.
MOVL $0, sig+12(FP)
- MOVQ addr+0(FP), AX
+ MOVQ ptr+0(FP), AX
MOVL (AX), BX
MOVL BX, val+8(FP)
RET
+
+// func addrOfLoadUint32() uintptr
+TEXT ·addrOfLoadUint32(SB), $0-8
+ MOVQ $·loadUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s
index d58ed71f7..246a049ba 100644
--- a/pkg/safecopy/atomic_arm64.s
+++ b/pkg/safecopy/atomic_arm64.s
@@ -25,7 +25,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24
// handleSwapUint32Fault will store a different value in this address.
MOVW $0, sig+20(FP)
again:
- MOVD addr+0(FP), R0
+ MOVD ptr+0(FP), R0
MOVW new+8(FP), R1
LDAXRW (R0), R2
STLXRW R1, (R0), R3
@@ -33,6 +33,12 @@ again:
MOVW R2, old+16(FP)
RET
+// func addrOfSwapUint32() uintptr
+TEXT ·addrOfSwapUint32(SB), $0-8
+ MOVD $·swapUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleSwapUint64Fault returns the value stored in R1. Control is transferred
// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
// number stored in R1.
@@ -54,7 +60,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28
// handleSwapUint64Fault will store a different value in this address.
MOVW $0, sig+24(FP)
again:
- MOVD addr+0(FP), R0
+ MOVD ptr+0(FP), R0
MOVD new+8(FP), R1
LDAXR (R0), R2
STLXR R1, (R0), R3
@@ -62,6 +68,12 @@ again:
MOVD R2, old+16(FP)
RET
+// func addrOfSwapUint64() uintptr
+TEXT ·addrOfSwapUint64(SB), $0-8
+ MOVD $·swapUint64(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is
// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS,
// with the signal number stored in R1.
@@ -84,7 +96,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
// address.
MOVW $0, sig+20(FP)
- MOVD addr+0(FP), R0
+ MOVD ptr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
again:
@@ -97,6 +109,12 @@ done:
MOVW R3, prev+16(FP)
RET
+// func addrOfCompareAndSwapUint32() uintptr
+TEXT ·addrOfCompareAndSwapUint32(SB), $0-8
+ MOVD $·compareAndSwapUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleLoadUint32Fault returns the value stored in DI. Control is transferred
// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -107,11 +125,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
MOVW R1, sig+12(FP)
RET
-// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS
// signal is received, the value returned is unspecified, and sig is the number
// of the signal that was received.
//
-// Preconditions: addr must be aligned to a 4-byte boundary.
+// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
TEXT ·loadUint32(SB), NOSPLIT, $0-16
@@ -120,7 +138,13 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16
// handleLoadUint32Fault will store a different value in this address.
MOVW $0, sig+12(FP)
- MOVD addr+0(FP), R0
+ MOVD ptr+0(FP), R0
LDARW (R0), R1
MOVW R1, val+8(FP)
RET
+
+// func addrOfLoadUint32() uintptr
+TEXT ·addrOfLoadUint32(SB), $0-8
+ MOVD $·loadUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s
index 64cf32f05..4abaecaff 100644
--- a/pkg/safecopy/memclr_amd64.s
+++ b/pkg/safecopy/memclr_amd64.s
@@ -145,3 +145,9 @@ _129through256:
MOVOU X0, -32(DI)(BX*1)
MOVOU X0, -16(DI)(BX*1)
RET
+
+// func addrOfMemclr() uintptr
+TEXT ·addrOfMemclr(SB), $0-8
+ MOVQ $·memclr(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s
index 7361b9067..c789bfeb3 100644
--- a/pkg/safecopy/memclr_arm64.s
+++ b/pkg/safecopy/memclr_arm64.s
@@ -72,3 +72,9 @@ head_loop:
CMP $16, R1
BLT tail_zero
B aligned_to_16
+
+// func addrOfMemclr() uintptr
+TEXT ·addrOfMemclr(SB), $0-8
+ MOVD $·memclr(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
index 00b46c18f..37316b2f5 100644
--- a/pkg/safecopy/memcpy_amd64.s
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -51,8 +51,8 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36
// handleMemcpyFault will store a different value in this address.
MOVL $0, sig+32(FP)
- MOVQ to+0(FP), DI
- MOVQ from+8(FP), SI
+ MOVQ dst+0(FP), DI
+ MOVQ src+8(FP), SI
MOVQ n+16(FP), BX
tail:
@@ -217,3 +217,9 @@ move_129through256:
MOVOU -16(SI)(BX*1), X15
MOVOU X15, -16(DI)(BX*1)
RET
+
+// func addrOfMemcpy() uintptr
+TEXT ·addrOfMemcpy(SB), $0-8
+ MOVQ $·memcpy(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s
index e7e541565..50f5b754b 100644
--- a/pkg/safecopy/memcpy_arm64.s
+++ b/pkg/safecopy/memcpy_arm64.s
@@ -33,8 +33,8 @@ TEXT ·memcpy(SB), NOSPLIT, $-8-36
// handleMemcpyFault will store a different value in this address.
MOVW $0, sig+32(FP)
- MOVD to+0(FP), R3
- MOVD from+8(FP), R4
+ MOVD dst+0(FP), R3
+ MOVD src+8(FP), R4
MOVD n+16(FP), R5
CMP $0, R5
BNE check
@@ -76,3 +76,9 @@ forwardtailloop:
CMP R3, R9
BNE forwardtailloop
RET
+
+// func addrOfMemcpy() uintptr
+TEXT ·addrOfMemcpy(SB), $0-8
+ MOVD $·memcpy(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index 1e0af5889..df63dd5f1 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -18,7 +18,6 @@ package safecopy
import (
"fmt"
- "reflect"
"runtime"
"golang.org/x/sys/unix"
@@ -91,6 +90,11 @@ var (
// signals.
func signalHandler()
+// addrOfSignalHandler returns the start address of signalHandler.
+//
+// See comment on addrOfMemcpy for more details.
+func addrOfSignalHandler() uintptr
+
// FindEndAddress returns the end address (one byte beyond the last) of the
// function that contains the specified address (begin).
func FindEndAddress(begin uintptr) uintptr {
@@ -111,26 +115,26 @@ func initializeAddresses() {
// The following functions are written in assembly language, so they won't
// be inlined by the existing compiler/linker. Tests will fail if this
// assumption is violated.
- memcpyBegin = reflect.ValueOf(memcpy).Pointer()
+ memcpyBegin = addrOfMemcpy()
memcpyEnd = FindEndAddress(memcpyBegin)
- memclrBegin = reflect.ValueOf(memclr).Pointer()
+ memclrBegin = addrOfMemclr()
memclrEnd = FindEndAddress(memclrBegin)
- swapUint32Begin = reflect.ValueOf(swapUint32).Pointer()
+ swapUint32Begin = addrOfSwapUint32()
swapUint32End = FindEndAddress(swapUint32Begin)
- swapUint64Begin = reflect.ValueOf(swapUint64).Pointer()
+ swapUint64Begin = addrOfSwapUint64()
swapUint64End = FindEndAddress(swapUint64Begin)
- compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer()
+ compareAndSwapUint32Begin = addrOfCompareAndSwapUint32()
compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin)
- loadUint32Begin = reflect.ValueOf(loadUint32).Pointer()
+ loadUint32Begin = addrOfLoadUint32()
loadUint32End = FindEndAddress(loadUint32Begin)
}
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(unix.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil {
+ if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(unix.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil {
+ if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) {
diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
index 611f36253..55743e69c 100644
--- a/pkg/safecopy/safecopy_test.go
+++ b/pkg/safecopy/safecopy_test.go
@@ -19,8 +19,6 @@ import (
"fmt"
"io/ioutil"
"math/rand"
- "os"
- "runtime/debug"
"testing"
"unsafe"
@@ -568,63 +566,3 @@ func TestCompareAndSwapUint32BusError(t *testing.T) {
}
})
}
-
-func testCopy(dst, src []byte) (panicked bool) {
- defer func() {
- if r := recover(); r != nil {
- panicked = true
- }
- }()
- debug.SetPanicOnFault(true)
- copy(dst, src)
- return
-}
-
-func TestSegVOnMemmove(t *testing.T) {
- // Test that SIGSEGVs received by runtime.memmove when *not* doing
- // CopyIn or CopyOut work gets propagated to the runtime.
- const bufLen = pageSize
- a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE)
- if err != nil {
- t.Fatalf("Mmap failed: %v", err)
-
- }
- defer unix.Munmap(a)
- b := randBuf(bufLen)
-
- if !testCopy(b, a) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-
- if !testCopy(a, b) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-}
-
-func TestSigbusOnMemmove(t *testing.T) {
- // Test that SIGBUS received by runtime.memmove when *not* doing
- // CopyIn or CopyOut work gets propagated to the runtime.
- const bufLen = pageSize
- f, err := ioutil.TempFile("", "sigbus_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- os.Remove(f.Name())
- defer f.Close()
-
- a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
- if err != nil {
- t.Fatalf("Mmap failed: %v", err)
-
- }
- defer unix.Munmap(a)
- b := randBuf(bufLen)
-
- if !testCopy(b, a) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-
- if !testCopy(a, b) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-}
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index a075cf88e..efbc2ddc1 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -89,6 +89,18 @@ func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig
//go:noescape
func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+// The following functions return the start addresses of the functions above.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must take the address
+// in assembly to get the ABI0 (i.e., primary) address.
+func addrOfMemcpy() uintptr
+func addrOfMemclr() uintptr
+func addrOfSwapUint32() uintptr
+func addrOfSwapUint64() uintptr
+func addrOfCompareAndSwapUint32() uintptr
+func addrOfLoadUint32() uintptr
+
// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
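
To see why the trampolines are needed, a before/after sketch (the reflect form may now yield the address of an ABIInternal wrapper rather than the ABI0 entry point that the signal handler compares fault PCs against):

    // Pre-Go 1.17 pattern; may now return the wrapper's address:
    //   memcpyBegin := reflect.ValueOf(memcpy).Pointer()
    // Portable pattern; the assembly stub returns the ABI0 address:
    memcpyBegin := addrOfMemcpy()
    memcpyEnd := FindEndAddress(memcpyBegin)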
diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s
index 475ae48e9..0b5e8df66 100644
--- a/pkg/safecopy/sighandler_amd64.s
+++ b/pkg/safecopy/sighandler_amd64.s
@@ -131,3 +131,9 @@ handle_fault:
MOVL DI, REG_RDI(DX)
RET
+
+// func addrOfSignalHandler() uintptr
+TEXT ·addrOfSignalHandler(SB), $0-8
+ MOVQ $·signalHandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s
index 53e4ac2c1..41ed70ff9 100644
--- a/pkg/safecopy/sighandler_arm64.s
+++ b/pkg/safecopy/sighandler_arm64.s
@@ -141,3 +141,9 @@ handle_fault:
MOVW R0, REG_R1(R2)
RET
+
+// func addrOfSignalHandler() uintptr
+TEXT ·addrOfSignalHandler(SB), $0-8
+ MOVD $·signalHandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD
index 3fda3a9cc..2c7cc8769 100644
--- a/pkg/safemem/BUILD
+++ b/pkg/safemem/BUILD
@@ -14,6 +14,7 @@ go_library(
deps = [
"//pkg/gohacks",
"//pkg/safecopy",
+ "//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index 93879bb4f..4af534385 100644
--- a/pkg/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/sync"
)
// A Block is a range of contiguous bytes, similar to []byte but with the
@@ -223,8 +224,22 @@ func Copy(dst, src Block) (int, error) {
func Zero(dst Block) (int, error) {
if !dst.needSafecopy {
bs := dst.ToSlice()
- for i := range bs {
- bs[i] = 0
+ if !sync.RaceEnabled {
+ // If the race detector isn't enabled, the golang
+ // compiler replaces the next loop with memclr
+ // (https://github.com/golang/go/issues/5373).
+ for i := range bs {
+ bs[i] = 0
+ }
+ } else {
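+ // Under the race detector, zero via copy() so the accesses
+ // are instrumented: zero bs[0], then double the zeroed
+ // prefix until it covers the slice.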
+ bsLen := len(bs)
+ if bsLen == 0 {
+ return 0, nil
+ }
+ bs[0] = 0
+ for i := 1; i < bsLen; i *= 2 {
+ copy(bs[i:], bs[:i])
+ }
}
return len(bs), nil
}
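
The doubling loop zeroes the block in O(log n) instrumented copy calls instead of n instrumented byte stores. A self-contained sketch of the same trick:

    // zeroDoubling zeroes bs by repeatedly copying the already-zeroed prefix
    // over the bytes that follow it, doubling the zeroed region each pass.
    func zeroDoubling(bs []byte) {
            if len(bs) == 0 {
                    return
            }
            bs[0] = 0
            for i := 1; i < len(bs); i *= 2 {
                    copy(bs[i:], bs[:i]) // bs[:i] is already zero.
            }
    }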
diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go
index 1e9625bee..f0ba26736 100644
--- a/pkg/sentry/arch/fpu/fpu_amd64.go
+++ b/pkg/sentry/arch/fpu/fpu_amd64.go
@@ -219,6 +219,11 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid
return copy(*s, f), nil
}
+// SetMXCSR sets the MXCSR control/status register in the state.
+func (s *State) SetMXCSR(mxcsr uint32) {
+ hostarch.ByteOrder.PutUint32((*s)[mxcsrOffset:], mxcsr)
+}
+
// BytePointer returns a pointer to the first byte of the state.
//
//go:nosplit
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 1929e41cd..49c53452a 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
// "/dev/zero (deleted)".
opts.Offset = 0
opts.MappingIdentity = &fd.vfsfd
+ opts.SentryOwnedContent = true
opts.MappingIdentity.IncRef()
return nil
}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 0b3d0617f..46a2dc47d 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -384,8 +384,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
- // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
- // sockets.
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix
+ // domain sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_RCVBUF for host backed unix
+ // domain sockets. The receive buffer size has no effect for unix
+ // sockets, so we report the same value as the send buffer.
return atomic.LoadInt64(&c.sndbuf)
}
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
new file mode 100644
index 000000000..37efb641a
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -0,0 +1,48 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dir_refs",
+ out = "dir_refs.go",
+ package = "cgroupfs",
+ prefix = "dir",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "dir",
+ },
+)
+
+go_library(
+ name = "cgroupfs",
+ srcs = [
+ "base.go",
+ "cgroupfs.go",
+ "cpu.go",
+ "cpuacct.go",
+ "cpuset.go",
+ "dir_refs.go",
+ "job.go",
+ "memory.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/coverage",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
new file mode 100644
index 000000000..0f54888d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -0,0 +1,261 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// controllerCommon implements kernel.CgroupController.
+//
+// Must call init before use.
+//
+// +stateify savable
+type controllerCommon struct {
+ ty kernel.CgroupControllerType
+ fs *filesystem
+}
+
+func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) {
+ c.ty = ty
+ c.fs = fs
+}
+
+// Type implements kernel.CgroupController.Type.
+func (c *controllerCommon) Type() kernel.CgroupControllerType {
+ return kernel.CgroupControllerType(c.ty)
+}
+
+// HierarchyID implements kernel.CgroupController.HierarchyID.
+func (c *controllerCommon) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// NumCgroups implements kernel.CgroupController.NumCgroups.
+func (c *controllerCommon) NumCgroups() uint64 {
+ return atomic.LoadUint64(&c.fs.numCgroups)
+}
+
+// Enabled implements kernel.CgroupController.Enabled.
+//
+// Controllers are currently always enabled.
+func (c *controllerCommon) Enabled() bool {
+ return true
+}
+
+// Filesystem implements kernel.CgroupController.Filesystem.
+func (c *controllerCommon) Filesystem() *vfs.Filesystem {
+ return c.fs.VFSFilesystem()
+}
+
+// RootCgroup implements kernel.CgroupController.RootCgroup.
+func (c *controllerCommon) RootCgroup() kernel.Cgroup {
+ return c.fs.rootCgroup()
+}
+
+// controller is an interface for common functionality related to all cgroups.
+// It is an extension of the public cgroup interface, containing cgroup
+// functionality private to cgroupfs.
+type controller interface {
+ kernel.CgroupController
+
+ // AddControlFiles should extend the contents map with inodes representing
+ // control files defined by this controller.
+ AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+}
+
+// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
+//
+// +stateify savable
+type cgroupInode struct {
+ dir
+ fs *filesystem
+
+ // ts is the list of tasks in this cgroup. The kernel is responsible for
+ // removing tasks from this list before they're destroyed, so any tasks on
+ // this list are always valid.
+ //
+ // ts, and cgroup membership in general, are protected by fs.tasksMu.
+ ts map[*kernel.Task]struct{}
+}
+
+var _ kernel.CgroupImpl = (*cgroupInode)(nil)
+
+func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
+ c := &cgroupInode{
+ fs: fs,
+ ts: make(map[*kernel.Task]struct{}),
+ }
+
+ contents := make(map[string]kernfs.Inode)
+ contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
+ contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c})
+
+ for _, ctl := range fs.controllers {
+ ctl.AddControlFiles(ctx, creds, c, contents)
+ }
+
+ c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
+ c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ c.dir.InitRefs()
+ c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
+
+ atomic.AddUint64(&fs.numCgroups, 1)
+
+ return c
+}
+
+func (c *cgroupInode) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// Controllers implements kernel.CgroupImpl.Controllers.
+func (c *cgroupInode) Controllers() []kernel.CgroupController {
+ return c.fs.kcontrollers
+}
+
+// Enter implements kernel.CgroupImpl.Enter.
+func (c *cgroupInode) Enter(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ c.ts[t] = struct{}{}
+ c.fs.tasksMu.Unlock()
+}
+
+// Leave implements kernel.CgroupImpl.Leave.
+func (c *cgroupInode) Leave(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ delete(c.ts, t)
+ c.fs.tasksMu.Unlock()
+}
+
+func sortTIDs(tids []kernel.ThreadID) {
+ sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
+}
+
+// +stateify savable
+type cgroupProcsData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ pgids := make(map[kernel.ThreadID]struct{})
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ // The map dedups pgids, since iterating over all tasks produces
+ // multiple entries for each thread group.
+ if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
+ pgids[pgid] = struct{}{}
+ }
+ }
+
+ pgidList := make([]kernel.ThreadID, 0, len(pgids))
+ for pgid := range pgids {
+ pgidList = append(pgidList, pgid)
+ }
+ sortTIDs(pgidList)
+
+ for _, pgid := range pgidList {
+ fmt.Fprintf(buf, "%d\n", pgid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
+
+// +stateify savable
+type tasksData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ var pids []kernel.ThreadID
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ if pid := currPidns.IDOfTask(task); pid != 0 {
+ pids = append(pids, pid)
+ }
+ }
+ sortTIDs(pids)
+
+ for _, pid := range pids {
+ fmt.Fprintf(buf, "%d\n", pid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
+
+// parseInt64FromString interprets src as a string encoding an int64 value,
+// and returns the parsed value.
+func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset int64) (val, len int64, err error) {
+ const maxInt64StrLen = 20 // i.e. len(fmt.Sprintf("%d", math.MinInt64)) == 20
+
+ t := kernel.TaskFromContext(ctx)
+ src = src.DropFirst64(offset)
+
+ buf := t.CopyScratchBuffer(maxInt64StrLen)
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, int64(n), err
+ }
+ buf = buf[:n]
+
+ val, err = strconv.ParseInt(string(buf), 10, 64)
+ if err != nil {
+ // Note: This also handles zero-len writes if offset is beyond the end
+ // of src, or src is empty.
+ ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
+ return 0, int64(n), syserror.EINVAL
+ }
+
+ return val, int64(n), nil
+}
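
For orientation, here is a minimal sketch of how a new controller would plug into the types above. The "pids" controller shown is hypothetical (not part of this change), and the snippet assumes the package-internal helpers defined in cgroupfs.go below (filesystem, newStaticControllerFile):

package cgroupfs

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)

// pidsController illustrates the pattern: embed controllerCommon for the
// kernel.CgroupController methods, then implement AddControlFiles to
// contribute this controller's control files to each cgroup directory.
type pidsController struct {
	controllerCommon
}

var _ controller = (*pidsController)(nil)

func newPIDsController(fs *filesystem) *pidsController {
	c := &pidsController{}
	c.controllerCommon.init(kernel.CgroupControllerType("pids"), fs)
	return c
}

// AddControlFiles implements controller.AddControlFiles.
func (c *pidsController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
	contents["pids.max"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), "max\n")
}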
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
new file mode 100644
index 000000000..bd3e69757
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -0,0 +1,425 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroupfs implements cgroupfs.
+//
+// A cgroup is a collection of tasks on the system, organized into a tree-like
+// structure similar to a filesystem directory tree. In fact, each cgroup is
+// represented by a directory on cgroupfs, and is manipulated through control
+// files in the directory.
+//
+// All cgroups on a system are organized into hierarchies. A hierarchy is a
+// distinct tree of cgroups, with a common set of controllers. One or more
+// cgroupfs mounts may point to each hierarchy. These mounts provide a common
+// view into the same tree of cgroups.
+//
+// A controller (also known as a "resource controller", or a cgroup "subsystem")
+// determines the behavior of each cgroup.
+//
+// In addition to cgroupfs, the kernel has a cgroup registry that tracks
+// system-wide state related to cgroups such as active hierarchies and the
+// controllers associated with them.
+//
+// Since cgroupfs doesn't allow hardlinks, there is a one-to-one mapping
+// between cgroupfs dentries and inodes.
+//
+// # Synchronization
+//
+// Cgroup hierarchy creation and destruction is protected by the
+// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers, the
+// filesystem associated with it, and the root cgroup for the hierarchy are
+// immutable.
+//
+// Membership of tasks within cgroups is protected by
+// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups they're
+// in, and this list is protected by Task.mu.
+//
+// Lock order:
+//
+// kernel.CgroupRegistry.mu
+// cgroupfs.filesystem.mu
+// Task.mu
+// cgroupfs.filesystem.tasksMu.
+package cgroupfs
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // Name is the default filesystem name.
+ Name = "cgroup"
+ readonlyFileMode = linux.FileMode(0444)
+ writableFileMode = linux.FileMode(0644)
+ defaultMaxCachedDentries = uint64(1000)
+)
+
+const (
+ controllerCPU = kernel.CgroupControllerType("cpu")
+ controllerCPUAcct = kernel.CgroupControllerType("cpuacct")
+ controllerCPUSet = kernel.CgroupControllerType("cpuset")
+ controllerJob = kernel.CgroupControllerType("job")
+ controllerMemory = kernel.CgroupControllerType("memory")
+)
+
+var allControllers = []kernel.CgroupControllerType{
+ controllerCPU,
+ controllerCPUAcct,
+ controllerCPUSet,
+ controllerJob,
+ controllerMemory,
+}
+
+// SupportedMountOptions is the set of supported mount options for cgroupfs.
+var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "job", "memory"}
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+// InternalData contains internal data passed in to the cgroupfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
+type InternalData struct {
+ DefaultControlValues map[string]int64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // hierarchyID is the ID the cgroup registry assigns to this hierarchy. It
+ // has the value kernel.InvalidCgroupHierarchyID until the FS is fully
+ // initialized.
+ //
+ // hierarchyID is immutable after initialization.
+ hierarchyID uint32
+
+ // controllers and kcontrollers are both the list of controllers attached to
+ // this cgroupfs. Both lists are the same set of controllers, but typecast
+ // to different interfaces for convenience. Both must stay in sync, and are
+ // immutable.
+ controllers []controller
+ kcontrollers []kernel.CgroupController
+
+ numCgroups uint64 // Protected by atomic ops.
+
+ root *kernfs.Dentry
+
+ // tasksMu serializes task membership changes across all cgroups within a
+ // filesystem.
+ tasksMu sync.RWMutex `state:"nosave"`
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ var wantControllers []kernel.CgroupControllerType
+ if _, ok := mopts["cpu"]; ok {
+ delete(mopts, "cpu")
+ wantControllers = append(wantControllers, controllerCPU)
+ }
+ if _, ok := mopts["cpuacct"]; ok {
+ delete(mopts, "cpuacct")
+ wantControllers = append(wantControllers, controllerCPUAcct)
+ }
+ if _, ok := mopts["cpuset"]; ok {
+ delete(mopts, "cpuset")
+ wantControllers = append(wantControllers, controllerCPUSet)
+ }
+ if _, ok := mopts["job"]; ok {
+ delete(mopts, "job")
+ wantControllers = append(wantControllers, controllerJob)
+ }
+ if _, ok := mopts["memory"]; ok {
+ delete(mopts, "memory")
+ wantControllers = append(wantControllers, controllerMemory)
+ }
+ if _, ok := mopts["all"]; ok {
+ if len(wantControllers) > 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ delete(mopts, "all")
+ wantControllers = allControllers
+ }
+
+ if len(wantControllers) == 0 {
+ // Specifying no controllers implies all controllers.
+ wantControllers = allControllers
+ }
+
+ if len(mopts) != 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ // "It is not possible to mount the same controller against multiple
+ // cgroup hierarchies. For example, it is not possible to mount both
+ // the cpu and cpuacct controllers against one hierarchy, and to mount
+ // the cpu controller alone against another hierarchy." - man cgroups(7)
+ //
+ // Is there a hierarchy available with all the controllers we want? If so,
+ // this mount is a view into the same hierarchy.
+ //
+ // Note: we're guaranteed to have at least one requested controller, since
+ // no explicit controller name implies all controllers.
+ if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil {
+ fs := vfsfs.Impl().(*filesystem)
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
+ fs.root.IncRef()
+ return vfsfs, fs.root.VFSDentry(), nil
+ }
+
+ // No existing hierarchy with exactly the controllers we want was found.
+ // Make a new one. Note that it's possible this mount creation is
+ // unsatisfiable, if one or more of the requested controllers are already on
+ // existing hierarchies. We'll find out about such collisions when we try to
+ // register the new hierarchy later.
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.MaxCachedDentries = maxCachedDentries
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ var defaults map[string]int64
+ if opts.InternalData != nil {
+ defaults = opts.InternalData.(*InternalData).DefaultControlValues
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
+ }
+
+ for _, ty := range wantControllers {
+ var c controller
+ switch ty {
+ case controllerCPU:
+ c = newCPUController(fs, defaults)
+ case controllerCPUAcct:
+ c = newCPUAcctController(fs)
+ case controllerCPUSet:
+ c = newCPUSetController(fs)
+ case controllerJob:
+ c = newJobController(fs)
+ case controllerMemory:
+ c = newMemoryController(fs, defaults)
+ default:
+ panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty))
+ }
+ fs.controllers = append(fs.controllers, c)
+ }
+
+ if len(defaults) != 0 {
+ // Internal data is always provided at sentry startup and unused values
+ // indicate a problem with the sandbox config. Fail fast.
+ panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults))
+ }
+
+ // Controllers usually appear in alphabetical order when displayed. Sort the
+ // list here once, so it never needs to be sorted elsewhere.
+ sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() })
+ fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers))
+ for _, c := range fs.controllers {
+ fs.kcontrollers = append(fs.kcontrollers, c)
+ }
+
+ root := fs.newCgroupInode(ctx, creds)
+ var rootD kernfs.Dentry
+ rootD.InitRoot(&fs.Filesystem, root)
+ fs.root = &rootD
+
+ // Register controllers. The registry may be modified concurrently, so if we
+ // get an error, we raced with someone else who registered the same
+ // controllers first.
+ hid, err := r.Register(fs.kcontrollers)
+ if err != nil {
+ ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
+ rootD.DecRef(ctx)
+ fs.VFSFilesystem().DecRef(ctx)
+ return nil, nil, syserror.EBUSY
+ }
+ fs.hierarchyID = hid
+
+ // Move all existing tasks to the root of the new hierarchy.
+ k.PopulateNewCgroupHierarchy(fs.rootCgroup())
+
+ return fs.VFSFilesystem(), rootD.VFSDentry(), nil
+}
+
+func (fs *filesystem) rootCgroup() kernel.Cgroup {
+ return kernel.Cgroup{
+ Dentry: fs.root,
+ CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+ }
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ if fs.hierarchyID != kernel.InvalidCgroupHierarchyID {
+ k.ReleaseCgroupHierarchy(fs.hierarchyID)
+ r.Unregister(fs.hierarchyID)
+ }
+
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ var cnames []string
+ for _, c := range fs.controllers {
+ cnames = append(cnames, string(c.Type()))
+ }
+ return strings.Join(cnames, ",")
+}
+
+// +stateify savable
+type implStatFS struct{}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// dir implements kernfs.Inode for a generic cgroup resource controller
+// directory. Specific controllers extend this to add their own functionality.
+//
+// +stateify savable
+type dir struct {
+ dirRefs
+ kernfs.InodeAlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+ kernfs.OrderedChildren
+ implStatFS
+
+ locks vfs.FileLocks
+}
+
+// Keep implements kernfs.Inode.Keep.
+func (*dir) Keep() bool {
+ return true
+}
+
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+ SeekEnd: kernfs.SeekEndStaticEntries,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// DecRef implements kernfs.Inode.DecRef.
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// controllerFile represents a generic control file that appears within a cgroup
+// directory.
+//
+// +stateify savable
+type controllerFile struct {
+ kernfs.DynamicBytesFile
+}
+
+func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode)
+ return f
+}
+
+func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode)
+ return f
+}
+
+// staticControllerFile represents a generic control file that appears within a
+// cgroup directory which always returns the same data when read.
+// staticControllerFiles are not writable.
+//
+// +stateify savable
+type staticControllerFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+// Note: We let the caller provide the mode so that static files may be used to
+// fake both readable and writable control files. However, static files are
+// effectively readonly, as attempting to write to them will return EIO
+// regardless of the mode.
+func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode {
+ f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode)
+ return f
+}
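
As a standalone illustration of the mount-option semantics implemented in GetFilesystem above (explicit names select a subset of controllers, "all" selects everything, combining the two is an error, naming nothing defaults to all, and leftover options are rejected), here is a runnable sketch with plain strings standing in for the sentry types:

package main

import (
	"errors"
	"fmt"
)

// resolveControllers mirrors the controller-selection logic in GetFilesystem.
func resolveControllers(mopts map[string]string) ([]string, error) {
	all := []string{"cpu", "cpuacct", "cpuset", "job", "memory"}
	var want []string
	for _, name := range all {
		if _, ok := mopts[name]; ok {
			delete(mopts, name)
			want = append(want, name)
		}
	}
	if _, ok := mopts["all"]; ok {
		if len(want) > 0 {
			return nil, errors.New("EINVAL: other controllers specified with all")
		}
		delete(mopts, "all")
		want = all
	}
	if len(want) == 0 {
		want = all // Specifying no controllers implies all controllers.
	}
	if len(mopts) != 0 {
		return nil, fmt.Errorf("EINVAL: unknown options: %v", mopts)
	}
	return want, nil
}

func main() {
	fmt.Println(resolveControllers(map[string]string{"cpu": "", "cpuacct": ""}))
	fmt.Println(resolveControllers(map[string]string{})) // defaults to all five
}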
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
new file mode 100644
index 000000000..24d86a277
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpuController struct {
+ controllerCommon
+
+ // CFS bandwidth control parameters, values in microseconds.
+ cfsPeriod int64
+ cfsQuota int64
+
+ // CPU shares, with a typical value of (number of cores * 1024).
+ shares int64
+}
+
+var _ controller = (*cpuController)(nil)
+
+func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController {
+ // Default values for controller parameters from Linux.
+ c := &cpuController{
+ cfsPeriod: 100000,
+ cfsQuota: -1,
+ shares: 1024,
+ }
+
+ if val, ok := defaults["cpu.cfs_period_us"]; ok {
+ c.cfsPeriod = val
+ delete(defaults, "cpu.cfs_period_us")
+ }
+ if val, ok := defaults["cpu.cfs_quota_us"]; ok {
+ c.cfsQuota = val
+ delete(defaults, "cpu.cfs_quota_us")
+ }
+ if val, ok := defaults["cpu.shares"]; ok {
+ c.shares = val
+ delete(defaults, "cpu.shares")
+ }
+
+ c.controllerCommon.init(controllerCPU, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod))
+ contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota))
+ contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares))
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
new file mode 100644
index 000000000..d4104a00e
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type cpuacctController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpuacctController)(nil)
+
+func newCPUAcctController(fs *filesystem) *cpuacctController {
+ c := &cpuacctController{}
+ c.controllerCommon.init(controllerCPUAcct, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) {
+ cpuacctCG := &cpuacctCgroup{cg}
+ contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG})
+ contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG})
+ contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG})
+ contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG})
+}
+
+// +stateify savable
+type cpuacctCgroup struct {
+ *cgroupInode
+}
+
+func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats {
+ var cs usage.CPUStats
+ c.fs.tasksMu.RLock()
+ // Note: This isn't very accurate, since the tasks are potentially
+ // still running as we accumulate their stats.
+ for t := range c.ts {
+ cs.Accumulate(t.CPUStats())
+ }
+ c.fs.tasksMu.RUnlock()
+ return cs
+}
+
+// +stateify savable
+type cpuacctStatData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime))
+ fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime))
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageUserData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageSysData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds())
+ return nil
+}
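
For concreteness, a standalone sketch of the four output formats generated above. It assumes USER_HZ = 100 for the clock-tick conversion performed by linux.ClockTFromDuration, which is the common value but an assumption here:

package main

import (
	"bytes"
	"fmt"
	"time"
)

func main() {
	// Hypothetical accumulated times for a cgroup's tasks.
	user, sys := 1500*time.Millisecond, 250*time.Millisecond

	// cpuacct.stat reports times in clock ticks (USER_HZ assumed 100).
	const userHZ = 100
	var stat bytes.Buffer
	fmt.Fprintf(&stat, "user %d\n", user.Nanoseconds()*userHZ/int64(time.Second))
	fmt.Fprintf(&stat, "system %d\n", sys.Nanoseconds()*userHZ/int64(time.Second))
	fmt.Print(stat.String()) // "user 150\nsystem 25\n"

	// The usage files report nanoseconds.
	fmt.Printf("%d\n", user.Nanoseconds()+sys.Nanoseconds()) // cpuacct.usage
	fmt.Printf("%d\n", user.Nanoseconds())                   // cpuacct.usage_user
	fmt.Printf("%d\n", sys.Nanoseconds())                    // cpuacct.usage_sys
}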
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
new file mode 100644
index 000000000..ac547f8e2
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpusetController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpusetController)(nil)
+
+func newCPUSetController(fs *filesystem) *cpusetController {
+ c := &cpusetController{}
+ c.controllerCommon.init(controllerCPUSet, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ // This controller is currently intentionally empty.
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go
new file mode 100644
index 000000000..48919c338
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/job.go
@@ -0,0 +1,64 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// +stateify savable
+type jobController struct {
+ controllerCommon
+ id int64
+}
+
+var _ controller = (*jobController)(nil)
+
+func newJobController(fs *filesystem) *jobController {
+ c := &jobController{}
+ c.controllerCommon.init(controllerJob, fs)
+ return c
+}
+
+func (c *jobController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["job.id"] = c.fs.newControllerWritableFile(ctx, creds, &jobIDData{c: c})
+}
+
+// +stateify savable
+type jobIDData struct {
+ c *jobController
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.c.id)
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ val, n, err := parseInt64FromString(ctx, src, offset)
+ if err != nil {
+ return n, err
+ }
+ d.c.id = val
+ return n, nil
+}
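
A toy, self-contained model of the job.id read/write cycle above, with plain strings standing in for the sentry's usermem types; the EINVAL-on-garbage behavior comes from parseInt64FromString in base.go:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// jobID mimics jobIDData: writes parse a decimal int64 (EINVAL on garbage),
// reads print the stored value followed by a newline.
type jobID struct{ id int64 }

func (j *jobID) Write(src string) (int, error) {
	val, err := strconv.ParseInt(strings.TrimSpace(src), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("EINVAL: failed to parse %q", src)
	}
	j.id = val
	return len(src), nil
}

func (j *jobID) Read() string { return fmt.Sprintf("%d\n", j.id) }

func main() {
	var j jobID
	if _, err := j.Write("42\n"); err != nil {
		panic(err)
	}
	fmt.Print(j.Read()) // "42\n"
}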
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
new file mode 100644
index 000000000..485c98376
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type memoryController struct {
+ controllerCommon
+
+ limitBytes int64
+}
+
+var _ controller = (*memoryController)(nil)
+
+func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
+ c := &memoryController{
+ // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
+ // is ~2**63 on a 64-bit system. So essentially infinity. The exact
+ // value isn't very important.
+ limitBytes: math.MaxInt64,
+ }
+ if val, ok := defaults["memory.limit_in_bytes"]; ok {
+ c.limitBytes = val
+ delete(defaults, "memory.limit_in_bytes")
+ }
+ c.controllerCommon.init(controllerMemory, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
+ contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+}
+
+// +stateify savable
+type memoryUsageInBytesData struct{}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/183151557): This is a giant hack, we're using system-wide
+ // accounting since we know there is only one cgroup.
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
+ mf.UpdateUsage()
+ _, totalBytes := usage.MemoryAccounting.Copy()
+
+ fmt.Fprintf(buf, "%d\n", totalBytes)
+ return nil
+}
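
For reference, the Linux default mentioned in the comment above works out as follows. This is a sketch under the assumption that PAGE_COUNTER_MAX is LONG_MAX / PAGE_SIZE with 4KiB pages, so the product lands just under 2**63, which is why math.MaxInt64 is a fine stand-in:

package main

import (
	"fmt"
	"math"
)

func main() {
	const pageSize = 4096 // Assumed 4KiB pages.
	pageCounterMax := int64(math.MaxInt64) / pageSize
	fmt.Println(pageCounterMax * pageSize) // 9223372036854771712, just under 2**63.
	fmt.Println(int64(math.MaxInt64))      // 9223372036854775807, the value used here.
}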
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 7b1eec3da..2dbc6bfd5 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -46,7 +46,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fd",
"//pkg/fspath",
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 6d5258a9b..52879f871 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"host_named_pipe.go",
"p9file.go",
"regular_file.go",
+ "revalidate.go",
"save_restore.go",
"socket.go",
"special_file.go",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 43c3c5a2d..97ce80853 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
return ds
}
+// Precondition: !parent.isSynthetic() && !child.isSynthetic().
+func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) {
+ // The new child took a ref on the parent (so the parent may now need to be
+ // removed from the cache), and the child itself has 0 refs for now (so it
+ // may need to be inserted). checkCachingLocked() should therefore be called
+ // on both. Call it on the parent first, as that may free space in the cache
+ // for the child, avoiding an immediate eviction.
+ *ds = appendDentry(*ds, parent)
+ *ds = appendDentry(*ds, child)
+}
+
// Preconditions: ds != nil.
func putDentrySlice(ds *[]*dentry) {
// Allow dentries to be GC'd.
@@ -141,21 +152,8 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **
return
}
ds := **dsp
- // Only go through calling dentry.checkCachingLocked() (which requires
- // re-locking renameMu) if we actually have any dentries with zero refs.
- checkAny := false
- for i := range ds {
- if atomic.LoadInt64(&ds[i].refs) == 0 {
- checkAny = true
- break
- }
- }
- if checkAny {
- fs.renameMu.Lock()
- for _, d := range ds {
- d.checkCachingLocked(ctx)
- }
- fs.renameMu.Unlock()
+ for _, d := range ds {
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
putDentrySlice(*dsp)
}
@@ -166,7 +164,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]
return
}
for _, d := range **ds {
- d.checkCachingLocked(ctx)
+ d.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
fs.renameMu.Unlock()
putDentrySlice(*ds)
@@ -182,165 +180,96 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
// * !rp.Done().
-// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
-// to date.
+// * If !d.cachedMetadataAuthoritative(), then d and all children that are
+// part of rp must have been revalidated.
//
// Postconditions: The returned dentry's cached metadata is up to date.
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) {
if !d.isDir() {
- return nil, syserror.ENOTDIR
+ return nil, false, syserror.ENOTDIR
}
if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
- return nil, err
+ return nil, false, err
}
+ followedSymlink := false
afterSymlink:
name := rp.Component()
if name == "." {
rp.Advance()
- return d, nil
+ return d, followedSymlink, nil
}
if name == ".." {
if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
- return nil, err
+ return nil, false, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return d, nil
- }
- // We must assume that d.parent is correct, because if d has been moved
- // elsewhere in the remote filesystem so that its parent has changed,
- // we have no way of determining its new parent's location in the
- // filesystem.
- //
- // Call rp.CheckMount() before updating d.parent's metadata, since if
- // we traverse to another mount then d.parent's metadata is irrelevant.
- if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
- return nil, err
+ return d, followedSymlink, nil
}
- if d != d.parent && !d.cachedMetadataAuthoritative() {
- if err := d.parent.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, false, err
}
rp.Advance()
- return d.parent, nil
+ return d.parent, followedSymlink, nil
}
- child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds)
+ child, err := fs.getChildLocked(ctx, d, name, ds)
if err != nil {
- return nil, err
- }
- if child == nil {
- return nil, syserror.ENOENT
+ return nil, false, err
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
- return nil, err
+ return nil, false, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx, rp.Mount())
if err != nil {
- return nil, err
+ return nil, false, err
}
if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ return nil, false, err
}
+ followedSymlink = true
goto afterSymlink // don't check the current directory again
}
rp.Advance()
- return child, nil
+ return child, followedSymlink, nil
}
// getChildLocked returns a dentry representing the child of parent with the
-// given name. If no such child exists, getChildLocked returns (nil, nil).
+// given name. Returns ENOENT if the child doesn't exist.
//
// Preconditions:
// * fs.renameMu must be locked.
// * parent.dirMu must be locked.
// * parent.isDir().
// * name is not "." or "..".
-//
-// Postconditions: If getChildLocked returns a non-nil dentry, its cached
-// metadata is up to date.
-func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+// * The dentry at name has been revalidated.
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if len(name) > maxFilenameLen {
return nil, syserror.ENAMETOOLONG
}
- child, ok := parent.children[name]
- if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
- // Whether child is nil or not, it is cached information that is
- // assumed to be correct.
+ if child, ok := parent.children[name]; ok || parent.isSynthetic() {
+ if child == nil {
+ return nil, syserror.ENOENT
+ }
return child, nil
}
- // We either don't have cached information or need to verify that it's
- // still correct, either of which requires a remote lookup. Check if this
- // name is valid before performing the lookup.
- return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
-}
-// Preconditions: Same as getChildLocked, plus:
-// * !parent.isSynthetic().
-func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
- if child != nil {
- // Need to lock child.metadataMu because we might be updating child
- // metadata. We need to hold the lock *before* getting metadata from the
- // server and release it after updating local metadata.
- child.metadataMu.Lock()
- }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
- if err != nil && err != syserror.ENOENT {
- if child != nil {
- child.metadataMu.Unlock()
+ if err != nil {
+ if err == syserror.ENOENT {
+ parent.cacheNegativeLookupLocked(name)
}
return nil, err
}
- if child != nil {
- if !file.isNil() && qid.Path == child.qidPath {
- // The file at this path hasn't changed. Just update cached metadata.
- file.close(ctx)
- child.updateFromP9AttrsLocked(attrMask, &attr)
- child.metadataMu.Unlock()
- return child, nil
- }
- child.metadataMu.Unlock()
- if file.isNil() && child.isSynthetic() {
- // We have a synthetic file, and no remote file has arisen to
- // replace it.
- return child, nil
- }
- // The file at this path has changed or no longer exists. Mark the
- // dentry invalidated, and re-evaluate its caching status (i.e. if it
- // has 0 references, drop it). Wait to update parent.children until we
- // know what to replace the existing dentry with (i.e. one of the
- // returns below), to avoid a redundant map access.
- vfsObj.InvalidateDentry(ctx, &child.vfsd)
- if child.isSynthetic() {
- // Normally we don't mark invalidated dentries as deleted since
- // they may still exist (but at a different path), and also for
- // consistency with Linux. However, synthetic files are guaranteed
- // to become unreachable if their dentries are invalidated, so
- // treat their invalidation as deletion.
- child.setDeleted()
- parent.syntheticChildren--
- child.decRefNoCaching()
- parent.dirents = nil
- }
- *ds = appendDentry(*ds, child)
- }
- if file.isNil() {
- // No file exists at this path now. Cache the negative lookup if
- // allowed.
- parent.cacheNegativeLookupLocked(name)
- return nil, nil
- }
+
// Create a new dentry representing the file.
- child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
if err != nil {
file.close(ctx)
delete(parent.children, name)
return nil, err
}
parent.cacheNewChildLocked(child, name)
- // For now, child has 0 references, so our caller should call
- // child.checkCachingLocked().
- *ds = appendDentry(*ds, child)
+ appendNewChildDentry(ds, parent, child)
return child, nil
}
@@ -355,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
// to date.
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
}
d = next
+ if followedSymlink {
+ if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
+ }
}
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -375,20 +312,22 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions: fs.renameMu must be locked.
func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
d := rp.Start().Impl().(*dentry)
- if !d.cachedMetadataAuthoritative() {
- // Get updated metadata for rp.Start() as required by fs.stepLocked().
- if err := d.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
+ if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+ return nil, err
}
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
}
d = next
+ if followedSymlink {
+ if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
+ }
}
if rp.MustBeDir() && !d.isDir() {
return nil, syserror.ENOTDIR
@@ -408,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return err
- }
- }
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return err
@@ -432,25 +364,47 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if parent.isDeleted() {
return syserror.ENOENT
}
+ if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil {
+ return err
+ }
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds)
- switch {
- case err != nil && err != syserror.ENOENT:
- return err
- case child != nil:
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+ // Check for existence only if cached information is available. Otherwise,
+ // defer the check: the creation RPCs below fail if the name already exists,
+ // so an explicit existence check is only needed if the writability checks
+ // below fail first.
+ if child, ok := parent.children[name]; ok && child != nil {
return syserror.EEXIST
}
+ checkExistence := func() error {
+ if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT {
+ return err
+ } else if child != nil {
+ return syserror.EEXIST
+ }
+ return nil
+ }
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
+ // Existence check takes precedence.
+ if existenceErr := checkExistence(); existenceErr != nil {
+ return existenceErr
+ }
return err
}
defer mnt.EndWrite()
if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ // Existence check takes precedence.
+ if existenceErr := checkExistence(); existenceErr != nil {
+ return existenceErr
+ }
return err
}
if !dir && rp.MustBeDir() {
@@ -500,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return err
- }
- }
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return err
@@ -532,33 +479,32 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return syserror.EISDIR
}
}
+
vfsObj := rp.VirtualFilesystem()
+ if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil {
+ return err
+ }
+
mntns := vfs.MountNamespaceFromContext(ctx)
defer mntns.DecRef(ctx)
+
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- child, ok := parent.children[name]
- if ok && child == nil {
- return syserror.ENOENT
- }
-
- sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
- if sticky {
- if !ok {
- // If the sticky bit is set, we need to retrieve the child to determine
- // whether removing it is allowed.
- child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
- if err != nil {
- return err
- }
- } else if child != nil && !child.cachedMetadataAuthoritative() {
- // Make sure the dentry representing the file at name is up to date
- // before examining its metadata.
- child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
- if err != nil {
- return err
- }
+ // If the sticky bit is set, we need to load the child to determine whether
+ // deletion is allowed.
+ var child *dentry
+ if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 {
+ var ok bool
+ child, ok = parent.children[name]
+ if ok && child == nil {
+ // Hit a negative cached entry, child doesn't exist.
+ return syserror.ENOENT
+ }
+ } else {
+ child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err != nil {
+ return err
}
if err := parent.mayDelete(rp.Credentials(), child); err != nil {
return err
@@ -567,11 +513,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// If a child dentry exists, prepare to delete it. This should fail if it is
// a mount point. We detect mount points by speculatively calling
- // PrepareDeleteDentry, which fails if child is a mount point. However, we
- // may need to revalidate the file in this case to make sure that it has not
- // been deleted or replaced on the remote fs, in which case the mount point
- // will have disappeared. If calling PrepareDeleteDentry fails again on the
- // up-to-date dentry, we can be sure that it is a mount point.
+ // PrepareDeleteDentry, which fails if child is a mount point.
//
// Also note that if child is nil, then it can't be a mount point.
if child != nil {
@@ -586,23 +528,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
child.dirMu.Lock()
defer child.dirMu.Unlock()
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- // We can skip revalidation in several cases:
- // - We are not in InteropModeShared
- // - The parent directory is synthetic, in which case the child must also
- // be synthetic
- // - We already updated the child during the sticky bit check above
- if parent.cachedMetadataAuthoritative() || sticky {
- return err
- }
- child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
- if err != nil {
- return err
- }
- if child != nil {
- if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- return err
- }
- }
+ return err
}
}
flags := uint32(0)
@@ -723,6 +649,8 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
}
}
d.IncRef()
+ // Call d.checkCachingLocked() so it can be removed from the cache if needed.
+ ds = appendDentry(ds, d)
return &d.vfsd, nil
}
@@ -732,18 +660,13 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
- }
d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return nil, err
}
d.IncRef()
+ // Call d.checkCachingLocked() so it can be removed from the cache if needed.
+ ds = appendDentry(ds, d)
return &d.vfsd, nil
}
@@ -782,7 +705,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
// If the parent is a setgid directory, use the parent's GID
// rather than the caller's and enable setgid.
kgid := creds.EffectiveKGID
@@ -802,6 +725,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kuid: creds.EffectiveKUID,
kgid: creds.EffectiveKGID,
})
+ *ds = appendDentry(*ds, parent)
}
if fs.opts.interop != InteropModeShared {
parent.incLinks()
@@ -836,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// to creating a synthetic one, i.e. one that is kept entirely in memory.
// Check that we're not overriding an existing file with a synthetic one.
- _, err = fs.stepLocked(ctx, rp, parent, true, ds)
+ _, _, err = fs.stepLocked(ctx, rp, parent, true, ds)
switch {
case err == nil:
// Step succeeded, another file exists.
@@ -855,6 +779,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid: creds.EffectiveKGID,
endpoint: opts.Endpoint,
})
+ *ds = appendDentry(*ds, parent)
return nil
case linux.S_IFIFO:
parent.createSyntheticChildLocked(&createSyntheticOpts{
@@ -864,6 +789,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid: creds.EffectiveKGID,
pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
+ *ds = appendDentry(*ds, parent)
return nil
}
// Retain error from gofer if synthetic file cannot be created internally.
@@ -895,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
defer unlock()
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by fs.stepLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
- }
if rp.Done() {
// Reject attempts to open mount root directory with O_CREAT.
if mayCreate && rp.MustBeDir() {
@@ -909,9 +829,17 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
+ if !start.cachedMetadataAuthoritative() {
+ // Refresh dentry's attributes before opening.
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
start.IncRef()
defer start.DecRef(ctx)
unlock()
+ // start is intentionally not added to ds (which would remove it from the
+ // cache) because doing so regresses performance in practice.
return start.open(ctx, rp, &opts)
}
@@ -928,9 +856,12 @@ afterTrailingSymlink:
if mayCreate && rp.MustBeDir() {
return nil, syserror.EISDIR
}
+ if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil {
+ return nil, err
+ }
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
if parent.isSynthetic() {
parent.dirMu.Unlock()
@@ -965,6 +896,8 @@ afterTrailingSymlink:
child.IncRef()
defer child.DecRef(ctx)
unlock()
+ // child is intentionally not added to ds (which would remove it from the
+ // cache) because doing so regresses performance in practice.
return child.open(ctx, rp, &opts)
}
@@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
- *ds = appendDentry(*ds, child)
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
@@ -1212,6 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
// Insert the dentry into the tree.
d.cacheNewChildLocked(child, name)
+ appendNewChildDentry(ds, d, child)
if d.cachedMetadataAuthoritative() {
d.touchCMtime()
d.dirents = nil
@@ -1296,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
+
vfsObj := rp.VirtualFilesystem()
+ if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil {
+ return err
+ }
+ if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil {
+ return err
+ }
+
// We need a dentry representing the renamed file since, if it's a
// directory, we need to check for write permission on it.
oldParent.dirMu.Lock()
defer oldParent.dirMu.Unlock()
- renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds)
+ renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
if err != nil {
return err
}
- if renamed == nil {
- return syserror.ENOENT
- }
if err := oldParent.mayDelete(creds, renamed); err != nil {
return err
}
@@ -1336,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.isDeleted() {
return syserror.ENOENT
}
- replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
- if err != nil {
+ replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil && err != syserror.ENOENT {
return err
}
var replacedVFSD *vfs.Dentry
@@ -1401,8 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// parent isn't actually changing.
if oldParent != newParent {
oldParent.decRefNoCaching()
- ds = appendDentry(ds, oldParent)
newParent.IncRef()
+ ds = appendDentry(ds, newParent)
+ ds = appendDentry(ds, oldParent)
if renamed.isSynthetic() {
oldParent.syntheticChildren--
newParent.syntheticChildren++
@@ -1546,6 +1485,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
if d.isSocket() {
if !d.isSynthetic() {
d.IncRef()
+ ds = appendDentry(ds, d)
return &endpoint{
dentry: d,
path: opts.Addr,
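
To make the error-precedence rule in the doCreateAt changes above concrete, here is a small standalone sketch (all names invented): when a writability check fails but the target name already exists, EEXIST wins, matching what the creation RPCs themselves would report.

package main

import (
	"errors"
	"fmt"
)

var (
	errEEXIST = errors.New("EEXIST")
	errEROFS  = errors.New("EROFS")
)

// create models doCreateAt's ordering: if a writability check fails, report
// EEXIST instead when the target name already exists, since the existence
// error takes precedence.
func create(exists func() bool, beginWrite func() error) error {
	if err := beginWrite(); err != nil {
		if exists() {
			return errEEXIST
		}
		return err
	}
	// ... issue the creation RPC, which itself fails with EEXIST on a race.
	return nil
}

func main() {
	err := create(
		func() bool { return true },         // name already exists
		func() error { return errEROFS },    // mount is read-only
	)
	fmt.Println(err) // EEXIST: existence wins over the read-only mount error.
}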
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index a0c05231a..21692d2ac 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -18,21 +18,23 @@
// Lock order:
// regularFileFD/directoryFD.mu
// filesystem.renameMu
-// dentry.dirMu
-// filesystem.syncMu
-// dentry.metadataMu
-// *** "memmap.Mappable locks" below this point
-// dentry.mapsMu
-// *** "memmap.Mappable locks taken by Translate" below this point
-// dentry.handleMu
-// dentry.dataMu
-// filesystem.inoMu
+// dentry.cachingMu
+// filesystem.cacheMu
+// dentry.dirMu
+// filesystem.syncMu
+// dentry.metadataMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.handleMu
+// dentry.dataMu
+// filesystem.inoMu
// specialFileFD.mu
// specialFileFD.bufMu
//
-// Locking dentry.dirMu in multiple dentries requires that either ancestor
-// dentries are locked before descendant dentries, or that filesystem.renameMu
-// is locked for writing.
+// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that
+// either ancestor dentries are locked before descendant dentries, or that
+// filesystem.renameMu is locked for writing.
package gofer
import (
@@ -140,7 +142,8 @@ type filesystem struct {
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
// cachedDentriesLen is the number of dentries in cachedDentries. These fields
- // are protected by renameMu.
+ // are protected by cacheMu.
+ cacheMu sync.Mutex `state:"nosave"`
cachedDentries dentryList
cachedDentriesLen uint64
@@ -620,11 +623,11 @@ func (fs *filesystem) Release(ctx context.Context) {
// the reference count on every synthetic dentry. Synthetic dentries have one
// reference for existence that should be dropped during filesystem.Release.
//
-// Precondition: d.fs.renameMu is locked.
+// Precondition: d.fs.renameMu is locked for writing.
func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
if d.isSynthetic() {
d.decRefNoCaching()
- d.checkCachingLocked(ctx)
+ d.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
if d.isDir() {
var children []*dentry
@@ -682,9 +685,13 @@ type dentry struct {
// deleted. deleted is accessed using atomic memory operations.
deleted uint32
+ // cachingMu is used to synchronize concurrent dentry caching attempts on
+ // this dentry.
+ cachingMu sync.Mutex `state:"nosave"`
+
// If cached is true, dentryEntry links dentry into
// filesystem.cachedDentries. cached and dentryEntry are protected by
- // filesystem.renameMu.
+ // cachingMu.
cached bool
dentryEntry
@@ -980,36 +987,63 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
// Preconditions:
// * !d.isSynthetic().
+// * d.metadataMu is locked.
+func (d *dentry) refreshSizeLocked(ctx context.Context) error {
+ d.handleMu.RLock()
+
+ if d.writeFD < 0 {
+ d.handleMu.RUnlock()
+ // Ask the gofer if we don't have a host FD.
+ return d.updateFromGetattrLocked(ctx)
+ }
+
+ var stat unix.Statx_t
+ err := unix.Statx(int(d.writeFD), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat)
+ d.handleMu.RUnlock() // must be released before updateSizeLocked()
+ if err != nil {
+ return err
+ }
+ d.updateSizeLocked(stat.Size)
+ return nil
+}
+
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
- // Use d.readFile or d.writeFile, which represent 9P fids that have been
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ return d.updateFromGetattrLocked(ctx)
+}
+
+// Preconditions:
+// * !d.isSynthetic().
+// * d.metadataMu is locked.
+func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
+ // Use d.readFile or d.writeFile, which represent 9P FIDs that have been
// opened, in preference to d.file, which represents a 9P FID that has not.
// This may be significantly more efficient in some implementations. Prefer
// d.writeFile over d.readFile since some filesystem implementations may
// update a writable handle's metadata after writes to that handle, without
// making metadata updates immediately visible to read-only handles
// representing the same file.
- var (
- file p9file
- handleMuRLocked bool
- )
- // d.metadataMu must be locked *before* we getAttr so that we do not end up
- // updating stale attributes in d.updateFromP9AttrsLocked().
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
d.handleMu.RLock()
- if !d.writeFile.isNil() {
+ handleMuRLocked := true
+ var file p9file
+ switch {
+ case !d.writeFile.isNil():
file = d.writeFile
- handleMuRLocked = true
- } else if !d.readFile.isNil() {
+ case !d.readFile.isNil():
file = d.readFile
- handleMuRLocked = true
- } else {
+ default:
file = d.file
d.handleMu.RUnlock()
+ handleMuRLocked = false
}
+
_, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
if handleMuRLocked {
- d.handleMu.RUnlock()
+ d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked()
}
if err != nil {
return err
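
The statx-on-an-open-FD pattern that refreshSizeLocked uses above also works standalone; a minimal sketch, assuming golang.org/x/sys/unix (the file path is arbitrary):

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.Open("/etc/hostname") // any open FD works
	if err != nil {
		panic(err)
	}
	defer f.Close()

	var stat unix.Statx_t
	// With AT_EMPTY_PATH and an empty name, statx targets the FD itself;
	// STATX_SIZE asks the kernel to fill in only the size field.
	if err := unix.Statx(int(f.Fd()), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat); err != nil {
		panic(err)
	}
	fmt.Println("size:", stat.Size)
}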
@@ -1104,24 +1138,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
defer d.metadataMu.Unlock()
// As with Linux, if the UID, GID, or file size is changing, we have to
- // clear permission bits. Note that when set, clearSGID causes
- // permissions to be updated, but does not modify stat.Mask, as
- // modification would cause an extra inotify flag to be set.
- clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
- stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+ // clear permission bits. Note that when set, clearSGID may cause
+ // permissions to be updated.
+ clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) ||
+ (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) ||
stat.Mask&linux.STATX_SIZE != 0
if clearSGID {
if stat.Mask&linux.STATX_MODE != 0 {
stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
} else {
- stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+ oldMode := atomic.LoadUint32(&d.mode)
+ if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode {
+ stat.Mode = uint16(updatedMode)
+ stat.Mask |= linux.STATX_MODE
+ }
}
}
if !d.isSynthetic() {
if stat.Mask != 0 {
if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
GID: stat.Mask&linux.STATX_GID != 0,
Size: stat.Mask&linux.STATX_SIZE != 0,
@@ -1156,7 +1193,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
+ if stat.Mask&linux.STATX_MODE != 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
if stat.Mask&linux.STATX_UID != 0 {
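
The mode-mask logic above can be shown in isolation. In this sketch the constants and the clear helper are simplified stand-ins for linux.STATX_MODE and vfs.ClearSUIDAndSGID, not the real definitions:

package main

import "fmt"

const (
	statxMode = 0x1    // stand-in for linux.STATX_MODE
	modeSUID  = 0o4000 // S_ISUID
	modeSGID  = 0o2000 // S_ISGID
	groupExec = 0o0010 // S_IXGRP
)

// clearSUIDAndSGID mimics the effect of vfs.ClearSUIDAndSGID: the setuid
// bit is always dropped, the setgid bit only for group-executable files.
func clearSUIDAndSGID(mode uint32) uint32 {
	mode &^= modeSUID
	if mode&groupExec != 0 {
		mode &^= modeSGID
	}
	return mode
}

func main() {
	var mask uint32
	oldMode := uint32(0o6755) // setuid+setgid rwxr-xr-x
	if updated := clearSUIDAndSGID(oldMode); updated != oldMode {
		// Only widen the setattr mask when the mode actually changed, so
		// that no spurious mode-change notification is generated.
		mask |= statxMode
		oldMode = updated
	}
	fmt.Printf("mask=%#x mode=%#o\n", mask, oldMode) // mask=0x1 mode=0755
}

The point is the mask: STATX_MODE is added only when clearing actually changed the mode, so the gofer never sees a no-op mode update.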
@@ -1312,9 +1349,7 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
if d.decRefNoCaching() == 0 {
- d.fs.renameMu.Lock()
- d.checkCachingLocked(ctx)
- d.fs.renameMu.Unlock()
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
}
@@ -1374,15 +1409,16 @@ func (d *dentry) Watches() *vfs.Watches {
//
// If no watches are left on this dentry and it has no references, cache it.
func (d *dentry) OnZeroWatches(ctx context.Context) {
- if atomic.LoadInt64(&d.refs) == 0 {
- d.fs.renameMu.Lock()
- d.checkCachingLocked(ctx)
- d.fs.renameMu.Unlock()
- }
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
-// checkCachingLocked should be called after d's reference count becomes 0 or it
-// becomes disowned.
+// checkCachingLocked should be called after d's reference count becomes 0 or
+// it becomes disowned.
+//
+// For performance, checkCachingLocked can also be called after d's reference
+// count becomes non-zero, so that d can be removed from the LRU cache. This
+// may help in reducing the size of the cache and hence reduce evictions. Note
+// that this is not necessary for correctness.
//
// It may be called on a destroyed dentry. For example,
// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
@@ -1390,33 +1426,46 @@ func (d *dentry) OnZeroWatches(ctx context.Context) {
// operation. One of the calls may destroy the dentry, so subsequent calls will
// do nothing.
//
-// Preconditions: d.fs.renameMu must be locked for writing; it may be
-// temporarily unlocked.
-func (d *dentry) checkCachingLocked(ctx context.Context) {
- // Dentries with a non-zero reference count must be retained. (The only way
- // to obtain a reference on a dentry with zero references is via path
- // resolution, which requires renameMu, so if d.refs is zero then it will
- // remain zero while we hold renameMu for writing.)
+// Preconditions: d.fs.renameMu must be locked for writing if
+// renameMuWriteLocked is true; it may be temporarily unlocked.
+func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) {
+ d.cachingMu.Lock()
refs := atomic.LoadInt64(&d.refs)
if refs == -1 {
// Dentry has already been destroyed.
+ d.cachingMu.Unlock()
return
}
if refs > 0 {
- // This isn't strictly necessary (fs.cachedDentries is permitted to
- // contain dentries with non-zero refs, which are skipped by
- // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but
- // since we are already holding fs.renameMu for writing we may as well.
+ // fs.cachedDentries is permitted to contain dentries with non-zero refs,
+ // which are skipped by fs.evictCachedDentryLocked() upon reaching the end
+ // of the LRU. But it is still beneficial to remove d from the cache as we
+ // are already holding d.cachingMu. Keeping a cleaner cache also reduces
+ // the number of evictions (which is expensive as it acquires fs.renameMu).
d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
return
}
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
+ d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by d.destroyLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ // Now that renameMu is locked for writing, no more refs can be taken on
+ // d because path resolution requires renameMu for reading at least.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ // Destroy d only if its ref is still 0. If not, either someone took a
+ // ref on it or it got destroyed before fs.renameMu could be acquired.
+ return
+ }
+ }
if d.isDeleted() {
d.watches.HandleDeletion(ctx)
}
- d.removeFromCacheLocked()
d.destroyLocked(ctx)
return
}
@@ -1426,24 +1475,36 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
// d.watches cannot concurrently transition from zero to non-zero, because
// adding a watch requires holding a reference on d.
if d.watches.Size() > 0 {
- // As in the refs > 0 case, this is not strictly necessary.
+ // As in the refs > 0 case, removing d is beneficial.
d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
return
}
if atomic.LoadInt32(&d.fs.released) != 0 {
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as
+ // needed by d.destroyLocked() later.
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
if d.parent != nil {
d.parent.dirMu.Lock()
delete(d.parent.children, d.name)
d.parent.dirMu.Unlock()
}
d.destroyLocked(ctx)
+ return
}
+ d.fs.cacheMu.Lock()
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentries.PushFront(d)
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
return
}
// Cache the dentry, then evict the least recently used cached dentry if
@@ -1451,18 +1512,28 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
d.fs.cachedDentries.PushFront(d)
d.fs.cachedDentriesLen++
d.cached = true
- if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
+ shouldEvict := d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+
+ if shouldEvict {
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by
+ // d.evictCachedDentryLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
d.fs.evictCachedDentryLocked(ctx)
- // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
- // back down to fs.opts.maxCachedDentries, so we don't loop.
}
}
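
Reduced to a sketch, the new protocol is: take the small per-dentry lock first, decide cheaply, and escalate to renameMu only when destruction or eviction is actually required, re-checking refs after the escalation. Names here are stand-ins, and the real function also handles watches and the LRU:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type fs struct {
	renameMu sync.RWMutex // big lock: required to destroy or evict
}

type dentry struct {
	cachingMu sync.Mutex // small lock: serializes caching decisions
	refs      int64
}

// checkCaching mirrors the escalation in checkCachingLocked: cachingMu
// alone is enough for the common LRU bookkeeping, and renameMu is taken
// lazily; refs must be re-checked after acquiring it, because a new
// reference may appear in the window between the two locks.
func checkCaching(f *fs, d *dentry, renameMuWriteLocked bool) {
	d.cachingMu.Lock()
	if atomic.LoadInt64(&d.refs) > 0 {
		d.cachingMu.Unlock()
		return // still referenced; nothing to destroy
	}
	d.cachingMu.Unlock()

	if !renameMuWriteLocked {
		f.renameMu.Lock()
		defer f.renameMu.Unlock()
		if atomic.LoadInt64(&d.refs) != 0 {
			return // re-check: the dentry was revived meanwhile
		}
	}
	fmt.Println("destroying dentry")
}

func main() {
	checkCaching(&fs{}, &dentry{}, false)
}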
-// Preconditions: d.fs.renameMu must be locked for writing.
+// Preconditions: d.cachingMu must be locked.
func (d *dentry) removeFromCacheLocked() {
if d.cached {
+ d.fs.cacheMu.Lock()
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
+ d.fs.cacheMu.Unlock()
d.cached = false
}
}
@@ -1477,28 +1548,43 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
// Preconditions:
// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
-// * fs.cachedDentriesLen != 0.
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ fs.cacheMu.Lock()
victim := fs.cachedDentries.Back()
+ fs.cacheMu.Unlock()
+ if victim == nil {
+ // fs.cachedDentries may have become empty between when it was checked and
+ // when we locked fs.cacheMu.
+ return
+ }
+
+ victim.cachingMu.Lock()
victim.removeFromCacheLocked()
// victim.refs or victim.watches.Size() may have become non-zero from an
// earlier path resolution since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 {
- if victim.parent != nil {
- victim.parent.dirMu.Lock()
- if !victim.vfsd.IsDead() {
- // Note that victim can't be a mount point (in any mount
- // namespace), since VFS holds references on mount points.
- fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
- delete(victim.parent.children, victim.name)
- // We're only deleting the dentry, not the file it
- // represents, so we don't need to update
- // victimParent.dirents etc.
- }
- victim.parent.dirMu.Unlock()
+ if atomic.LoadInt64(&victim.refs) != 0 || victim.watches.Size() != 0 {
+ victim.cachingMu.Unlock()
+ return
+ }
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
}
- victim.destroyLocked(ctx)
+ victim.parent.dirMu.Unlock()
}
+ // Safe to unlock cachingMu now that victim.vfsd.IsDead(). Henceforth any
+ // concurrent caching attempts on victim will attempt to destroy it and so
+ // will try to acquire fs.renameMu (which we have already acquired). Hence,
+ // fs.renameMu will synchronize the destroy attempts.
+ victim.cachingMu.Unlock()
+ victim.destroyLocked(ctx)
}
// destroyLocked destroys the dentry.
@@ -1584,7 +1670,7 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// Drop the reference held by d on its parent without recursively locking
// d.fs.renameMu.
if d.parent != nil && d.parent.decRefNoCaching() == 0 {
- d.parent.checkCachingLocked(ctx)
+ d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
refsvfs2.Unregister(d)
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 76f08e252..806392d50 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -55,7 +55,7 @@ func TestDestroyIdempotent(t *testing.T) {
fs.renameMu.Lock()
defer fs.renameMu.Unlock()
- child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
if got := atomic.LoadInt64(&child.refs); got != -1 {
t.Fatalf("child.refs=%d, want: -1", got)
}
@@ -63,6 +63,6 @@ func TestDestroyIdempotent(t *testing.T) {
if got := atomic.LoadInt64(&parent.refs); got != -1 {
t.Fatalf("parent.refs=%d, want: -1", got)
}
- child.checkCachingLocked(ctx)
- child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 21b4a96fe..b0a429d42 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err
ctx.UninterruptibleSleepFinish(false)
return fdobj, err
}
+
+func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ stats, err := f.file.MultiGetAttr(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return stats, err
+}
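
MultiGetAttr lets the client stat a whole chain of path components in one RPC; the revalidation code below depends on the reply being truncated at the first missing component. A toy stand-in for that contract (a flat map replaces the real 9P walk):

package main

import "fmt"

// fullStat is a simplified stand-in for p9.FullStat.
type fullStat struct {
	qidPath uint64
	size    uint64
}

// multiGetAttr is a toy stand-in for p9.File.MultiGetAttr: it stats the
// given components in order and stops at the first missing one, so the
// result may be shorter than names.
func multiGetAttr(dir map[string]fullStat, names []string) []fullStat {
	var stats []fullStat
	for _, name := range names {
		s, ok := dir[name]
		if !ok {
			break // caller detects this via len(stats) < len(names)
		}
		stats = append(stats, s)
	}
	return stats
}

func main() {
	dir := map[string]fullStat{"a": {1, 4096}, "b": {2, 128}}
	stats := multiGetAttr(dir, []string{"a", "b", "c"})
	fmt.Printf("got %d of %d stats\n", len(stats), 3) // "c" was not found
}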
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 47563538c..f0e7bbaf7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -204,18 +204,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
d := fd.dentry()
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
// If the fd was opened with O_APPEND, make sure the file size is updated.
	// There is a possible race here if the size is modified externally after
	// the metadata cache is updated.
if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
- if err := d.updateFromGetattr(ctx); err != nil {
+ if err := d.refreshSizeLocked(ctx); err != nil {
return 0, offset, err
}
}
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
-
// Set offset to file size if the fd was opened with O_APPEND.
if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
// Holding d.metadataMu is sufficient for reading d.size.
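
The hunk above moves the O_APPEND size refresh under d.metadataMu so the refreshed size is consumed in the same critical section that computes the append offset. A standalone sketch of that invariant, with a plain mutex standing in for metadataMu:

package main

import (
	"fmt"
	"sync"
)

type file struct {
	metadataMu sync.Mutex
	size       int64
}

// appendWrite computes the write offset from the cached size. Refreshing
// the size and consuming it inside one metadataMu critical section closes
// the window in which another writer could grow the file in between.
func (f *file) appendWrite(n int64, refresh func() int64) (offset int64) {
	f.metadataMu.Lock()
	defer f.metadataMu.Unlock()
	f.size = refresh() // stands in for refreshSizeLocked
	offset = f.size
	f.size += n
	return offset
}

func main() {
	f := &file{}
	fmt.Println(f.appendWrite(10, func() int64 { return 100 })) // 100
}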
@@ -701,6 +702,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
}
// After this point, d may be used as a memmap.Mappable.
d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ opts.SentryOwnedContent = d.fs.opts.forcePageCache
return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
}
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
new file mode 100644
index 000000000..8f81f0822
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -0,0 +1,386 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+type errPartialRevalidation struct{}
+
+// Error implements error.Error.
+func (errPartialRevalidation) Error() string {
+ return "partial revalidation"
+}
+
+type errRevalidationStepDone struct{}
+
+// Error implements error.Error.
+func (errRevalidationStepDone) Error() string {
+ return "stop revalidation"
+}
+
+// revalidatePath checks cached dentries for external modification. File
+// attributes are refreshed and cache is invalidated in case the dentry has been
+// deleted, or a new file/directory created in its place.
+//
+// Revalidation stops at symlinks and mount points. The caller is responsible
+// for revalidating again after symlinks are resolved and after changing to
+// different mounts.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error {
+ // Revalidation is done even if start is synthetic in case the path is
+ // something like: ../non_synthetic_file.
+ if fs.opts.interop != InteropModeShared {
+ return nil
+ }
+
+ // Copy resolving path to walk the path for revalidation.
+ rp := rpOrig.Copy()
+ err := fs.revalidate(ctx, rp, start, rp.Done, ds)
+ rp.Release(ctx)
+ return err
+}
+
+// revalidateParentDir does the same as revalidatePath, but stops at the parent.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error {
+ // Revalidation is done even if start is synthetic in case the path is
+ // something like: ../non_synthetic_file and parent is non synthetic.
+ if fs.opts.interop != InteropModeShared {
+ return nil
+ }
+
+ // Copy resolving path to walk the path for revalidation.
+ rp := rpOrig.Copy()
+ err := fs.revalidate(ctx, rp, start, rp.Final, ds)
+ rp.Release(ctx)
+ return err
+}
+
+// revalidateOne does the same as revalidatePath, but checks a single dentry.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error {
+	// Skip revalidation for interop modes other than InteropModeShared, or
+	// if the parent is synthetic (the child must then be synthetic too, and
+	// it cannot be replaced without first replacing the parent).
+ if parent.cachedMetadataAuthoritative() {
+ return nil
+ }
+
+ parent.dirMu.Lock()
+ child, ok := parent.children[name]
+ parent.dirMu.Unlock()
+ if !ok {
+ return nil
+ }
+
+ state := makeRevalidateState(parent)
+ defer state.release()
+
+ state.add(name, child)
+ return fs.revalidateHelper(ctx, vfsObj, state, ds)
+}
+
+// revalidate revalidates path components in rp until done returns true, or
+// until a mount point or symlink is reached. It may send multiple MultiGetAttr
+// calls to the gofer to handle ".." in the path.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error {
+ state := makeRevalidateState(start)
+ defer state.release()
+
+	// Skip the start dentry if it is synthetic, since a synthetic dentry
+	// cannot be replaced even if a matching file has been created in the
+	// remote filesystem.
+ if !start.isSynthetic() {
+ state.add("", start)
+ }
+
+done:
+ for cur := start; !done(); {
+ var err error
+ cur, err = fs.revalidateStep(ctx, rp, cur, state)
+ if err != nil {
+ switch err.(type) {
+ case errPartialRevalidation:
+ if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil {
+ return err
+ }
+
+ // Reset state to release any remaining locks and restart from where
+ // stepping stopped.
+ state.reset()
+ state.start = cur
+
+			// Skip the start dentry if it is synthetic, since a synthetic
+			// dentry cannot be replaced even if a matching file has been
+			// created in the remote filesystem.
+ if !cur.isSynthetic() {
+ state.add("", cur)
+ }
+
+ case errRevalidationStepDone:
+ break done
+
+ default:
+ return err
+ }
+ }
+ }
+ return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds)
+}
+
+// revalidateStep walks one element of the path and updates revalidationState
+// with the dentry if needed. It may also stop the stepping or ask for a
+// partial revalidation. Partial revalidation requires the caller to revalidate
+// the current revalidationState, release all locks, and resume stepping.
+// In case a symlink is hit, revalidation stops and the caller is responsible
+// for calling revalidate again after the symlink is resolved. Revalidation may
+// also stop for other reasons, like hitting a child not in the cache.
+//
+// Returns:
+// * (dentry, nil): step worked, continue stepping.
+// * (dentry, errPartialRevalidation): revalidation should be done with the
+// state gathered so far. Then continue stepping with the remainder of the
+// path, starting at `dentry`.
+// * (nil, errRevalidationStepDone): revalidation doesn't need to step any
+// further. It hit a symlink, a mount point, or an uncached dentry.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
+// * InteropModeShared is in effect (assumes no negative dentries).
+func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) {
+ switch name := rp.Component(); name {
+ case ".":
+ // Do nothing.
+
+ case "..":
+ // Partial revalidation is required when ".." is hit because metadata locks
+ // can only be acquired from parent to child to avoid deadlocks.
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+ return nil, errRevalidationStepDone{}
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, errPartialRevalidation{}
+ }
+ // We must assume that d.parent is correct, because if d has been moved
+ // elsewhere in the remote filesystem so that its parent has changed,
+ // we have no way of determining its new parent's location in the
+ // filesystem.
+ //
+ // Call rp.CheckMount() before updating d.parent's metadata, since if
+ // we traverse to another mount then d.parent's metadata is irrelevant.
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, errRevalidationStepDone{}
+ }
+ rp.Advance()
+ return d.parent, errPartialRevalidation{}
+
+ default:
+ d.dirMu.Lock()
+ child, ok := d.children[name]
+ d.dirMu.Unlock()
+ if !ok {
+ // child is not cached, no need to validate any further.
+ return nil, errRevalidationStepDone{}
+ }
+
+ state.add(name, child)
+
+ // Symlink must be resolved before continuing with revalidation.
+ if child.isSymlink() {
+ return nil, errRevalidationStepDone{}
+ }
+
+ d = child
+ }
+
+ rp.Advance()
+ return d, nil
+}
+
+// revalidateHelper calls the gofer to stat all dentries in `state`. It will
+// update or invalidate dentries in the cache based on the result.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error {
+ if len(state.names) == 0 {
+ return nil
+ }
+ // Lock metadata on all dentries *before* getting attributes for them.
+ state.lockAllMetadata()
+ stats, err := state.start.file.multiGetAttr(ctx, state.names)
+ if err != nil {
+ return err
+ }
+
+ i := -1
+ for d := state.popFront(); d != nil; d = state.popFront() {
+ i++
+ found := i < len(stats)
+ if i == 0 && len(state.names[0]) == 0 {
+ if found && !d.isSynthetic() {
+				// The first dentry is the start of the search; just update
+				// its attributes, since it cannot be replaced.
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ }
+ d.metadataMu.Unlock()
+ continue
+ }
+
+		// Note that synthetic dentries will always fail the comparison check
+		// below.
+ if !found || d.qidPath != stats[i].QID.Path {
+ d.metadataMu.Unlock()
+ if !found && d.isSynthetic() {
+ // We have a synthetic file, and no remote file has arisen to replace
+ // it.
+ return nil
+ }
+ // The file at this path has changed or no longer exists. Mark the
+ // dentry invalidated, and re-evaluate its caching status (i.e. if it
+ // has 0 references, drop it). The dentry will be reloaded next time it's
+ // accessed.
+ vfsObj.InvalidateDentry(ctx, &d.vfsd)
+
+ name := state.names[i]
+ d.parent.dirMu.Lock()
+
+ if d.isSynthetic() {
+ // Normally we don't mark invalidated dentries as deleted since
+ // they may still exist (but at a different path), and also for
+ // consistency with Linux. However, synthetic files are guaranteed
+ // to become unreachable if their dentries are invalidated, so
+ // treat their invalidation as deletion.
+ d.setDeleted()
+ d.decRefNoCaching()
+ *ds = appendDentry(*ds, d)
+
+ d.parent.syntheticChildren--
+ d.parent.dirents = nil
+ }
+
+ // Since the dirMu was released and reacquired, re-check that the
+ // parent's child with this name is still the same. Do not touch it if
+ // it has been replaced with a different one.
+ if child := d.parent.children[name]; child == d {
+ // Invalidate dentry so it gets reloaded next time it's accessed.
+ delete(d.parent.children, name)
+ }
+ d.parent.dirMu.Unlock()
+
+ return nil
+ }
+
+ // The file at this path hasn't changed. Just update cached metadata.
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ d.metadataMu.Unlock()
+ }
+
+ return nil
+}
+
+// revalidateStatePool caches revalidateState instances to save array
+// allocations for dentries and names.
+var revalidateStatePool = sync.Pool{
+ New: func() interface{} {
+ return &revalidateState{}
+ },
+}
+
+// revalidateState keeps state related to a revalidation request. It keeps
+// track of the {name, dentry} list being revalidated, as well as metadata
+// locks on the dentries. The list must be in ancestry order; in other words,
+// entry `n` must be a child of entry `n-1`.
+type revalidateState struct {
+	// start is the dentry from which the attribute search starts.
+ start *dentry
+
+	// names is the list of entry names whose attributes are to be refreshed.
+	// It must have the same length as dentries. The two are kept in separate
+	// slices because names is passed directly to File.MultiGetAttr().
+ names []string
+
+	// dentries is the list of dentries that correspond to the names above.
+	// dentry.metadataMu for each dentry is acquired by lockAllMetadata()
+	// before attributes are fetched.
+ dentries []*dentry
+
+	// locked indicates whether the metadata locks on dentries have been acquired.
+ locked bool
+}
+
+func makeRevalidateState(start *dentry) *revalidateState {
+ r := revalidateStatePool.Get().(*revalidateState)
+ r.start = start
+ return r
+}
+
+// release must be called after the caller is done with this object. It releases
+// all metadata locks and resources.
+func (r *revalidateState) release() {
+ r.reset()
+ revalidateStatePool.Put(r)
+}
+
+// Preconditions:
+// * d is a descendant of all dentries in r.dentries.
+func (r *revalidateState) add(name string, d *dentry) {
+ r.names = append(r.names, name)
+ r.dentries = append(r.dentries, d)
+}
+
+func (r *revalidateState) lockAllMetadata() {
+ for _, d := range r.dentries {
+ d.metadataMu.Lock()
+ }
+ r.locked = true
+}
+
+func (r *revalidateState) popFront() *dentry {
+ if len(r.dentries) == 0 {
+ return nil
+ }
+ d := r.dentries[0]
+ r.dentries = r.dentries[1:]
+ return d
+}
+
+// reset releases all metadata locks and resets all fields to allow this
+// instance to be reused.
+func (r *revalidateState) reset() {
+ if r.locked {
+ // Unlock any remaining dentries.
+ for _, d := range r.dentries {
+ d.metadataMu.Unlock()
+ }
+ r.locked = false
+ }
+ r.start = nil
+ r.names = r.names[:0]
+ r.dentries = r.dentries[:0]
+}
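
revalidateState's pooling is the standard sync.Pool recycle pattern: reset on release so the backing slices keep their capacity for later requests. A minimal standalone sketch (the real type also tracks dentries and lock state):

package main

import (
	"fmt"
	"sync"
)

// state is a pared-down stand-in for revalidateState: pooled, and reset
// on release so its backing slices are reused by later requests.
type state struct {
	names []string
}

var pool = sync.Pool{New: func() interface{} { return &state{} }}

func (s *state) release() {
	s.names = s.names[:0] // drop contents, keep capacity
	pool.Put(s)
}

func main() {
	s := pool.Get().(*state)
	s.names = append(s.names, "a", "b")
	fmt.Println(s.names)
	s.release()

	// A later request likely reuses the same object, with capacity intact
	// (sync.Pool gives no guarantee, which is why release must fully reset).
	s2 := pool.Get().(*state)
	fmt.Println(len(s2.names), cap(s2.names))
}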
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 3b90375b6..a81f550b1 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -460,6 +460,9 @@ func (i *inode) DecRef(ctx context.Context) {
if err := unix.Close(i.hostFD); err != nil {
log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
}
+ // We can't rely on fdnotifier when closing the fd, because the event may race
+ // with fdnotifier.RemoveFD. Instead, notify the queue explicitly.
+ i.queue.Notify(waiter.EventHUp | waiter.ReadableEvents | waiter.WritableEvents)
})
}
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
index 31301c715..c502d8e99 100644
--- a/pkg/sentry/fsimpl/host/save_restore.go
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -68,3 +68,10 @@ func (i *inode) afterLoad() {
}
}
}
+
+// afterLoad is invoked by stateify.
+func (c *ConnectedEndpoint) afterLoad() {
+ if err := c.initFromOptions(); err != nil {
+ panic(fmt.Sprintf("initFromOptions failed: %v", err))
+ }
+}
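
afterLoad re-derives state that cannot be serialized, which is exactly the split that initFromOptions creates. A standalone sketch of the pattern, with plain JSON standing in for the stateify framework (the fd value is made up):

package main

import (
	"encoding/json"
	"fmt"
)

// conn holds a kernel resource (fd) that cannot be serialized; only Addr
// survives a save/restore cycle.
type conn struct {
	Addr string `json:"addr"`
	fd   int    // not saved; must be re-established on load
}

// afterLoad re-derives the unsaved state, mirroring how
// ConnectedEndpoint.afterLoad calls initFromOptions.
func (c *conn) afterLoad() error {
	c.fd = 3 // stand-in for reopening the host socket from c.Addr
	return nil
}

func main() {
	saved, _ := json.Marshal(&conn{Addr: "hostfd:[42]", fd: 42})

	var restored conn
	_ = json.Unmarshal(saved, &restored)
	if err := restored.afterLoad(); err != nil {
		panic(fmt.Sprintf("initFromOptions failed: %v", err))
	}
	fmt.Printf("%+v\n", restored)
}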
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 60e237ac7..ca85f5601 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -39,7 +39,7 @@ import (
func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
// Set up an external transport.Endpoint using the host fd.
addr := fmt.Sprintf("hostfd:[%d]", hostFD)
- e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */)
+ e, err := NewConnectedEndpoint(hostFD, addr)
if err != nil {
return nil, err.ToError()
}
@@ -86,7 +86,10 @@ type ConnectedEndpoint struct {
// for restoring them.
func (c *ConnectedEndpoint) init() *syserr.Error {
c.InitRefs()
+ return c.initFromOptions()
+}
+func (c *ConnectedEndpoint) initFromOptions() *syserr.Error {
family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
return syserr.FromError(err)
@@ -123,7 +126,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
// The caller is responsible for calling Init(). Additionally, Release needs to
// be called twice because ConnectedEndpoint is both a transport.Receiver and
// transport.ConnectedEndpoint.
-func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) {
e := ConnectedEndpoint{
fd: hostFD,
addr: addr,
@@ -330,8 +333,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
- // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
- // sockets.
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix
+ // domain sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+	// gVisor does not permit setting of SO_RCVBUF for host backed unix
+	// domain sockets. The receive buffer has no effect for unix sockets,
+	// so we claim it is the same size as the send buffer.
return atomic.LoadInt64(&c.sndbuf)
}
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 65054b0ea..84b1c3745 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -25,8 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// DynamicBytesFile implements kernfs.Inode and represents a read-only
-// file whose contents are backed by a vfs.DynamicBytesSource.
+// DynamicBytesFile implements kernfs.Inode and represents a read-only file
+// whose contents are backed by a vfs.DynamicBytesSource. If data additionally
+// implements vfs.WritableDynamicBytesSource, the file also supports dispatching
+// writes to the implementer, but note that this will not update the source data.
//
// Must be instantiated with NewDynamicBytesFile or initialized with Init
// before first use.
@@ -40,7 +42,9 @@ type DynamicBytesFile struct {
InodeNotSymlink
locks vfs.FileLocks
- data vfs.DynamicBytesSource
+ // data can additionally implement vfs.WritableDynamicBytesSource to support
+ // writes.
+ data vfs.DynamicBytesSource
}
var _ Inode = (*DynamicBytesFile)(nil)
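
A sketch of the read/write split described above, using stand-in interfaces; the real vfs.DynamicBytesSource and vfs.WritableDynamicBytesSource take a context and an IOSequence, so the signatures below are simplified:

package main

import (
	"bytes"
	"fmt"
)

// dynamicBytesSource generates file contents on every read.
type dynamicBytesSource interface {
	Generate(buf *bytes.Buffer) error
}

// writableDynamicBytesSource optionally accepts writes; a write is
// dispatched to the implementer and does not change what Generate
// produces unless the implementer updates its own state.
type writableDynamicBytesSource interface {
	dynamicBytesSource
	Write(p []byte, offset int64) (int64, error)
}

type tunable struct{ value int }

func (t *tunable) Generate(buf *bytes.Buffer) error {
	_, err := fmt.Fprintf(buf, "%d\n", t.value)
	return err
}

func (t *tunable) Write(p []byte, _ int64) (int64, error) {
	_, err := fmt.Sscanf(string(p), "%d", &t.value)
	return int64(len(p)), err
}

func main() {
	var src writableDynamicBytesSource = &tunable{value: 1}
	src.Write([]byte("42\n"), 0)
	var buf bytes.Buffer
	src.Generate(&buf)
	fmt.Print(buf.String()) // 42
}

As the comment above notes, the write reaches the implementer, but Generate output only changes because tunable chooses to update its own state.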
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index badca4d9f..f50b0fb08 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -612,16 +612,24 @@ afterTrailingSymlink:
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- fs.mu.RLock()
defer fs.processDeferredDecRefs(ctx)
- defer fs.mu.RUnlock()
+
+ fs.mu.RLock()
d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
+ fs.mu.RUnlock()
return "", err
}
if !d.isSymlink() {
+ fs.mu.RUnlock()
return "", syserror.EINVAL
}
+
+ // Inode.Readlink() cannot be called holding fs locks.
+ d.IncRef()
+ defer d.DecRef(ctx)
+ fs.mu.RUnlock()
+
return d.inode.Readlink(ctx, rp.Mount())
}
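
The fix is the usual ref-then-unlock shape: pin the dentry with a reference, drop the filesystem-wide lock, and only then call into the inode, which may re-enter kernfs. A standalone sketch with stand-in types:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type dentry struct {
	refs   int64
	target string
}

func (d *dentry) incRef() { atomic.AddInt64(&d.refs, 1) }
func (d *dentry) decRef() { atomic.AddInt64(&d.refs, -1) }

// readlink mirrors the ReadlinkAt change above: resolve under fs.mu, then
// take a reference and drop fs.mu *before* calling into the inode, since
// the inode's Readlink may re-enter the filesystem and re-acquire fs.mu.
func readlink(fsMu *sync.RWMutex, d *dentry) string {
	fsMu.RLock()
	// ... path resolution would happen here ...
	d.incRef()
	fsMu.RUnlock()
	defer d.decRef()
	return d.target // stands in for d.inode.Readlink(ctx, mnt)
}

func main() {
	var mu sync.RWMutex
	fmt.Println(readlink(&mu, &dentry{target: "/tmp/x"}))
}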
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 565d723f0..6f699c9cd 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode {
return d.inode
}
+// FSLocalPath returns an absolute path to d, relative to the root of its
+// filesystem.
+func (d *Dentry) FSLocalPath() string {
+ var b fspath.Builder
+ _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ b.PrependByte('/')
+ return b.String()
+}
+
// The Inode interface maps filesystem-level operations that operate on paths to
// equivalent operations on specific filesystem nodes.
//
@@ -524,6 +534,9 @@ func (d *Dentry) Inode() Inode {
// - Checking that dentries passed to methods are of the appropriate file type.
// - Checking permissions.
//
+// Inode functions may be called while holding filesystem-wide locks and are
+// not allowed to call vfs functions that may reenter, unless otherwise noted.
+//
// Specific responsibilities of implementations are documented below.
type Inode interface {
// Methods related to reference counting. A generic implementation is
@@ -670,6 +683,9 @@ type inodeDirectory interface {
type inodeSymlink interface {
// Readlink returns the target of a symbolic link. If an inode is not a
// symlink, the implementation should return EINVAL.
+ //
+ // Readlink is called with no kernfs locks held, so it may reenter if needed
+ // to resolve symlink targets.
Readlink(ctx context.Context, mnt *vfs.Mount) (string, error)
// Getlink returns the target of a symbolic link, as used by path
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 254a8b062..ce8f55b1f 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
- var cgroups map[string]string
+ var fakeCgroupControllers map[string]string
if opts.InternalData != nil {
data := opts.InternalData.(*InternalData)
- cgroups = data.Cgroups
+ fakeCgroupControllers = data.Cgroups
}
- inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers)
var dentry kernfs.Dentry
dentry.InitRoot(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index fea138f93..d05cc1508 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -47,7 +47,7 @@ type taskInode struct {
var _ kernfs.Inode = (*taskInode)(nil)
-func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) {
if task.ExitState() == kernel.TaskExitDead {
return nil, syserror.ESRCH
}
@@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns
"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
- contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers)
+ contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers)
}
- if len(cgroupControllers) > 0 {
- contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ if len(fakeCgroupControllers) > 0 {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers))
+ } else {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task})
}
taskInode := &taskInode{task: task}
@@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
return &ioData{ioUsage: t}
}
-// newCgroupData creates inode that shows cgroup information.
-// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
-// member, there is one entry containing three colon-separated fields:
-// hierarchy-ID:controller-list:cgroup-path"
-func newCgroupData(controllers map[string]string) dynamicInode {
+// newFakeCgroupData creates an inode that shows fake cgroup
+// information passed in as mount options. From man 7 cgroups: "For
+// each cgroup hierarchy of which the process is a member, there is
+// one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+//
+// TODO(b/182488796): Remove once all users adopt cgroupfs.
+func newFakeCgroupData(controllers map[string]string) dynamicInode {
var buf bytes.Buffer
// The hierarchy ids must be positive integers (for cgroup v1), but the
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 02bf74dbc..4718fac7a 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error)
defer file.DecRef(ctx)
root := vfs.RootFromContext(ctx)
defer root.DecRef(ctx)
+
+	// Note: it's safe to reenter kernfs from Readlink if needed to resolve the path.
return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 85909d551..b294dfd6a 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -1100,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
func (fd *namespaceFD) Release(ctx context.Context) {
fd.inode.DecRef(ctx)
}
+
+// taskCgroupData generates data for /proc/[pid]/cgroup.
+//
+// +stateify savable
+type taskCgroupData struct {
+ dynamicBytesFileSetAttr
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*taskCgroupData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// When a task is exiting on Linux, its cgroup set is cleared and
+ // reset to the initial cgroup set, which is essentially the set of root
+ // cgroups. Because of this, the /proc/<pid>/cgroup file is always readable
+ // on Linux throughout a task's lifetime.
+ //
+ // The sentry removes tasks from cgroups during the exit process, but
+ // doesn't move them into an initial cgroup set, so partway through task
+	// exit this file would show the task in no cgroups, which is incorrect. Instead,
+ // once a task has left its cgroups, we return an error.
+ if d.task.ExitState() >= kernel.TaskExitInitiated {
+ return syserror.ESRCH
+ }
+
+ d.task.GenerateProcTaskCgroup(buf)
+ return nil
+}
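
The generator above (like the fake one in task.go) emits the /proc/[pid]/cgroup format quoted from man 7 cgroups, one hierarchy-ID:controller-list:cgroup-path line per hierarchy. A standalone sketch with made-up controllers and naively assigned hierarchy IDs:

package main

import (
	"bytes"
	"fmt"
	"sort"
)

func main() {
	// Controller name -> cgroup path, as would come from mount options.
	controllers := map[string]string{"cpu": "/docker/abc", "memory": "/docker/abc"}

	// Sort for a stable order; map iteration alone is randomized.
	names := make([]string, 0, len(controllers))
	for name := range controllers {
		names = append(names, name)
	}
	sort.Strings(names)

	var buf bytes.Buffer
	id := 1 // hierarchy IDs must be positive integers for cgroup v1
	for _, name := range names {
		fmt.Fprintf(&buf, "%d:%s:%s\n", id, name, controllers[name])
		id++
	}
	fmt.Print(buf.String())
	// Output:
	// 1:cpu:/docker/abc
	// 2:memory:/docker/abc
}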
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index fdc580610..cf905fae4 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -54,17 +54,18 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- // cgroupControllers is a map of controller name to directory in the
+ // fakeCgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
// in /proc/pid/cgroup if not nil.
- cgroupControllers map[string]string
+ fakeCgroupControllers map[string]string
}
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
+ "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}),
"cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
"filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}),
"loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}),
@@ -76,11 +77,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns
"uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
"version": fs.newInode(ctx, root, 0444, &versionData{}),
}
+ // If fakeCgroupControllers are provided, don't create a cgroupfs backed
+ // /proc/cgroup as it will not match the fake controllers.
+ if len(fakeCgroupControllers) == 0 {
+ contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{})
+ }
inode := &tasksInode{
- pidns: pidns,
- fs: fs,
- cgroupControllers: cgroupControllers,
+ pidns: pidns,
+ fs: fs,
+ fakeCgroupControllers: fakeCgroupControllers,
}
inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.InitRefs()
@@ -118,7 +124,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers)
+ return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers)
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index f0029cda6..045ed7a2d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- k := kernel.KernelFromContext(ctx)
- init := k.GlobalInit()
- if init == nil {
- // Attempted to read before the init Task is created. This can
- // only occur during startup, which should never need to read
- // this file.
- panic("Attempted to read version before initial Task is available")
- }
-
// /proc/version takes the form:
//
// "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
@@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// FIXME(mpratt): Using Version from the init task SyscallTable
// disregards the different version a task may have (e.g., in a uts
// namespace).
- ver := init.Leader().SyscallTable().Version
+ ver := kernelVersion(ctx)
fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
return nil
}
@@ -384,3 +375,47 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error
k.VFS().GenerateProcFilesystems(buf)
return nil
}
+
+// cgroupsData backs /proc/cgroups.
+//
+// +stateify savable
+type cgroupsData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*cgroupsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ r := kernel.KernelFromContext(ctx).CgroupRegistry()
+ r.GenerateProcCgroups(buf)
+ return nil
+}
+
+// cmdLineData backs /proc/cmdline.
+//
+// +stateify savable
+type cmdLineData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*cmdLineData)(nil)
+
+// Generate implements vfs.DynamicByteSource.Generate.
+func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release)
+ return nil
+}
+
+// kernelVersion returns the kernel version.
+func kernelVersion(ctx context.Context) kernel.Version {
+ k := kernel.KernelFromContext(ctx)
+ init := k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+ return init.Leader().SyscallTable().Version
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index d6f076cd6..e534fbca8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -47,6 +47,7 @@ var (
var (
tasksStaticFiles = map[string]testutil.DirentType{
+ "cmdline": linux.DT_REG,
"cpuinfo": linux.DT_REG,
"filesystems": linux.DT_REG,
"loadavg": linux.DT_REG,
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1d9280dae..14eb10dcd 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -122,11 +122,11 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs
}
func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
- // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs
- // should be mounted at debug/, but for our purposes, it is sufficient to
- // keep it in sys.
+ // Set up /sys/kernel/debug/kcov. Technically, debugfs should be
+ // mounted at debug/, but for our purposes, it is sufficient to keep it
+ // in sys.
var children map[string]kernfs.Inode
- if coverage.KcovAvailable() {
+ if coverage.KcovSupported() {
log.Debugf("Set up /sys/kernel/debug/kcov")
children = map[string]kernfs.Inode{
"debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index b3f9d1010..c766164c7 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/memutil",
+ "//pkg/metric",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 807e4f44a..33e52ce64 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -62,6 +63,8 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating platform: %v", err)
}
+ metric.CreateSentryMetrics()
+
kernel.VFS2Enabled = true
k := &kernel.Kernel{
Platform: plat,
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index cd849e87e..c45bddff6 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -488,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
file := fd.inode().impl.(*regularFile)
+ opts.SentryOwnedContent = true
return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 2da251233..d473a922d 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -18,10 +18,12 @@ go_library(
"//pkg/marshal/primitive",
"//pkg/merkletree",
"//pkg/refsvfs2",
+ "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 6cb1a23e0..3582d14c9 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -168,10 +168,6 @@ afterSymlink:
// Preconditions:
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
-//
-// TODO(b/166474175): Investigate all possible errors returned in this
-// function, and make sure we differentiate all errors that indicate unexpected
-// modifications to the file system from the ones that are not harmful.
func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -200,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -209,7 +205,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's hash.
@@ -223,12 +219,14 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
}
+ defer parentMerkleFD.DecRef(ctx)
+
// dataSize is the size of raw data for the Merkle tree. For a file,
// dataSize is the size of the whole file. For a directory, dataSize is
// the size of all its children's hashes.
@@ -241,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -251,7 +249,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := FileReadWriteSeeker{
@@ -264,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
Start: parent.lowerVD,
}, &vfs.StatOptions{})
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -276,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
var buf bytes.Buffer
parent.hashMu.RLock()
_, err = merkletree.Verify(&merkletree.VerifyParams{
- Out: &buf,
- File: &fdReader,
- Tree: &fdReader,
- Size: int64(parentSize),
- Name: parent.name,
- Mode: uint32(parentStat.Mode),
- UID: parentStat.UID,
- GID: parentStat.GID,
- Children: parent.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ Children: parent.childrenNames,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
@@ -294,7 +291,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
})
parent.hashMu.RUnlock()
if err != nil && err != io.EOF {
- return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
@@ -331,19 +328,21 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
return err
}
+ defer fd.DecRef(ctx)
+
merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
Size: sizeOfStringInt32,
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return err
@@ -351,7 +350,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
size, err := strconv.Atoi(merkleSize)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
if d.isDir() && len(d.childrenNames) == 0 {
@@ -361,14 +360,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err))
}
if err != nil {
return err
}
childrenOffset, err := strconv.Atoi(childrenOffString)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
}
childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
@@ -377,23 +376,23 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err))
}
if err != nil {
return err
}
childrenSize, err := strconv.Atoi(childrenSizeString)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
}
childrenNames := make([]byte, childrenSize)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err))
}
if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err))
}
}
@@ -405,15 +404,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
var buf bytes.Buffer
d.hashMu.RLock()
params := &merkletree.VerifyParams{
- Out: &buf,
- Tree: &fdReader,
- Size: int64(size),
- Name: d.name,
- Mode: uint32(stat.Mode),
- UID: stat.UID,
- GID: stat.GID,
- Children: d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ Children: d.childrenNames,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
@@ -438,7 +436,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
}
if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
- return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
}
d.mode = uint32(stat.Mode)
d.uid = stat.UID
@@ -471,7 +469,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// The file was previously accessed. If the
// file does not exist now, it indicates an
// unexpected modification to the file system.
- return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
}
if err != nil {
return nil, err
@@ -483,7 +481,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// does not exist now, it indicates an unexpected
// modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
}
if err != nil {
return nil, err
@@ -553,8 +551,8 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
childVD, err := parent.getLowerAt(ctx, vfsObj, name)
- if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name))
+ if parent.verityEnabled() && err == syserror.ENOENT {
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name))
}
if err != nil {
return nil, err
@@ -565,30 +563,31 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
defer childVD.DecRef(ctx)
childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
- if err == syserror.ENOENT {
- if !fs.allowRuntimeEnable {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name))
- }
- childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(merklePrefix + name),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_CREAT,
- Mode: 0644,
- })
- if err != nil {
- return nil, err
- }
- childMerkleFD.DecRef(ctx)
- childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
- if err != nil {
+ if err != nil {
+ if err == syserror.ENOENT {
+ if parent.verityEnabled() {
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name))
+ }
+ childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parent.lowerVD,
+ Start: parent.lowerVD,
+ Path: fspath.Parse(merklePrefix + name),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, err
+ }
+ childMerkleFD.DecRef(ctx)
+ childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
+ if err != nil {
+ return nil, err
+ }
+ } else {
return nil, err
}
}
- if err != nil && err != syserror.ENOENT {
- return nil, err
- }
// Clear the Merkle tree file if it is to be generated at runtime.
// TODO(b/182315468): Optimize the Merkle tree generate process to
@@ -632,8 +631,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childVD.IncRef()
childMerkleVD.IncRef()
- parent.IncRef()
- child.parent = parent
child.name = name
child.mode = uint32(stat.Mode)
@@ -657,6 +654,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
}
+ parent.IncRef()
+ child.parent = parent
+
return child, nil
}
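
The hunks above also reorder reference acquisition: parent.IncRef() and the child.parent link now happen only after every fallible step has succeeded, so error returns need no unwinding. A hedged sketch of the idea with illustrative types:

    package main

    import "fmt"

    type node struct {
    	refs   int
    	name   string
    	parent *node
    }

    func (n *node) IncRef() { n.refs++ }

    // verify is a placeholder for the Merkle verification steps.
    func verify(*node) error { return nil }

    func lookup(parent *node, name string) (*node, error) {
    	child := &node{refs: 1, name: name}
    	if err := verify(child); err != nil {
    		// No parent reference has been taken yet; nothing to undo.
    		return nil, err
    	}
    	// Taken last, once the child can no longer fail to construct.
    	parent.IncRef()
    	child.parent = parent
    	return child, nil
    }

    func main() {
    	root := &node{refs: 1, name: "/"}
    	child, _ := lookup(root, "a")
    	fmt.Println(child.name, root.refs) // a 2
    }
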
@@ -855,7 +855,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// missing, it indicates an unexpected modification to the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
}
@@ -878,7 +878,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -903,7 +903,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
})
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -921,7 +921,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err != nil {
if err == syserror.ENOENT {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
return nil, err
}
@@ -985,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
// StatAt implements vfs.FilesystemImpl.StatAt.
-// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should
-// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index a7d92a878..31d34ef60 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -34,6 +34,8 @@
package verity
import (
+ "bytes"
+ "encoding/hex"
"encoding/json"
"fmt"
"math"
@@ -44,19 +46,20 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -95,14 +98,18 @@ const (
)
var (
- // action specifies the action towards detected violation.
- action ViolationAction
-
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
verityMu sync.RWMutex
)
+// Mount option names for verityfs.
+const (
+ moptLowerPath = "lower_path"
+ moptRootHash = "root_hash"
+ moptRootName = "root_name"
+)
+
// HashAlgorithm is a type specifying the algorithm used to hash the file
// content.
type HashAlgorithm int
@@ -169,6 +176,12 @@ type filesystem struct {
// system.
alg HashAlgorithm
+	// action specifies the action to take on a detected violation.
+ action ViolationAction
+
+	// opts is the raw string of mount options, as passed in opts.Data.
+ opts string
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -191,9 +204,6 @@ type filesystem struct {
//
// +stateify savable
type InternalFilesystemOptions struct {
- // RootMerkleFileName is the name of the verity root Merkle tree file.
- RootMerkleFileName string
-
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
@@ -201,9 +211,6 @@ type InternalFilesystemOptions struct {
// system.
Alg HashAlgorithm
- // RootHash is the root hash of the overall verity file system.
- RootHash []byte
-
// AllowRuntimeEnable specifies whether the verity file system allows
// enabling verification for files (i.e. building Merkle trees) during
// runtime.
@@ -228,8 +235,8 @@ func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
// an unexpected modification to the file system was detected. In
// ErrorOnViolation mode, it returns EIO; otherwise it panics.
-func alertIntegrityViolation(msg string) error {
- if action == ErrorOnViolation {
+func (fs *filesystem) alertIntegrityViolation(msg string) error {
+ if fs.action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
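
Making alertIntegrityViolation a method moves the violation policy from a package-level variable to per-filesystem state, so multiple verity mounts can carry different policies concurrently. A minimal sketch, assuming only the two modes visible in this file (errEIO and fsPolicy are illustrative stand-ins for syserror.EIO and *filesystem):

    package main

    import "errors"

    type ViolationAction int

    const (
    	PanicOnViolation ViolationAction = iota
    	ErrorOnViolation
    )

    var errEIO = errors.New("input/output error") // stands in for syserror.EIO

    type fsPolicy struct{ action ViolationAction } // stands in for *filesystem

    func (fs *fsPolicy) alertIntegrityViolation(msg string) error {
    	if fs.action == ErrorOnViolation {
    		return errEIO
    	}
    	panic(msg)
    }

    func main() {
    	fs := &fsPolicy{action: ErrorOnViolation}
    	_ = fs.alertIntegrityViolation("detected unexpected modification")
    }
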
@@ -237,28 +244,99 @@ func alertIntegrityViolation(msg string) error {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ var rootHash []byte
+ if encodedRootHash, ok := mopts[moptRootHash]; ok {
+ delete(mopts, moptRootHash)
+ hash, err := hex.DecodeString(encodedRootHash)
+ if err != nil {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err)
+ return nil, nil, syserror.EINVAL
+ }
+ rootHash = hash
+ }
+ var lowerPathname string
+ if path, ok := mopts[moptLowerPath]; ok {
+ delete(mopts, moptLowerPath)
+ lowerPathname = path
+ }
+ rootName := "root"
+ if root, ok := mopts[moptRootName]; ok {
+ delete(mopts, moptRootName)
+ rootName = root
+ }
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Handle internal options.
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
- if !ok {
+ if len(lowerPathname) == 0 && !ok {
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- action = iopts.Action
-
- // Mount the lower file system. The lower file system is wrapped inside
- // verity, and should not be exposed or connected.
- mopts := &vfs.MountOptions{
- GetFilesystemOptions: iopts.LowerGetFSOptions,
- InternalMount: true,
+ if len(lowerPathname) != 0 {
+ if ok {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path")
+ return nil, nil, syserror.EINVAL
+ }
+ iopts = InternalFilesystemOptions{
+ AllowRuntimeEnable: len(rootHash) == 0,
+ Action: ErrorOnViolation,
+ }
}
- mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts)
- if err != nil {
- return nil, nil, err
+
+ var lowerMount *vfs.Mount
+ var mountedLowerVD vfs.VirtualDentry
+ // Use an existing mount if lowerPath is provided.
+ if len(lowerPathname) != 0 {
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ var err error
+ mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ lowerMount = mountedLowerVD.Mount()
+ defer mountedLowerVD.DecRef(ctx)
+ } else {
+ // Mount the lower file system. The lower file system is wrapped inside
+ // verity, and should not be exposed or connected.
+ mountOpts := &vfs.MountOptions{
+ GetFilesystemOptions: iopts.LowerGetFSOptions,
+ InternalMount: true,
+ }
+ mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts)
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerMount = mnt
}
fs := &filesystem{
creds: creds.Fork(),
alg: iopts.Alg,
- lowerMount: mnt,
+ lowerMount: lowerMount,
+ action: iopts.Action,
+ opts: opts.Data,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -266,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Construct the root dentry.
d := fs.newDentry()
d.refs = 1
- lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root())
+ lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root())
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + rootName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -311,7 +389,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// the root Merkle file, or it's never generated.
fs.vfsfs.DecRef(ctx)
d.DecRef(ctx)
- return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
+ return nil, nil, fs.alertIntegrityViolation("Failed to find root Merkle file")
}
// Clear the Merkle tree file if it is to be generated at runtime.
@@ -350,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
- d.hash = make([]byte, len(iopts.RootHash))
d.childrenNames = make(map[string]struct{})
+ d.hashMu.Lock()
+ d.hash = make([]byte, len(rootHash))
+ copy(d.hash, rootHash)
+ d.hashMu.Unlock()
+
+ fs.rootDentry = d
+
if !d.isDir() {
ctx.Warningf("verity root must be a directory")
return nil, nil, syserror.EINVAL
@@ -368,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Size: sizeOfStringInt32,
})
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err))
}
if err != nil {
return nil, nil, err
@@ -376,7 +460,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
off, err := strconv.Atoi(offString)
if err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
}
sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{
@@ -387,14 +471,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Size: sizeOfStringInt32,
})
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err))
}
if err != nil {
return nil, nil, err
}
size, err := strconv.Atoi(sizeString)
if err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
}
lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
@@ -404,19 +488,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err))
}
if err != nil {
return nil, nil, err
}
+ defer lowerMerkleFD.DecRef(ctx)
+
childrenNames := make([]byte, size)
if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err))
}
if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
}
if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
@@ -424,13 +510,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
- d.hashMu.Lock()
- copy(d.hash, iopts.RootHash)
- d.hashMu.Unlock()
d.vfsd.Init(d)
- fs.rootDentry = d
-
return &fs.vfsfs, &d.vfsd, nil
}
@@ -441,7 +522,7 @@ func (fs *filesystem) Release(ctx context.Context) {
// MountOptions implements vfs.FilesystemImpl.MountOptions.
func (fs *filesystem) MountOptions() string {
- return ""
+ return fs.opts
}
// dentry implements vfs.DentryImpl.
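
With the mount-option changes above, the raw opts.Data string is parsed at mount time and also preserved verbatim so MountOptions can report it back. A self-contained sketch of the parsing flow, where parseOpts stands in for vfs.GenericParseMountOptions:

    package main

    import (
    	"encoding/hex"
    	"fmt"
    	"strings"
    )

    // parseOpts splits a comma-separated "key=value" string into a map.
    func parseOpts(data string) map[string]string {
    	m := make(map[string]string)
    	for _, kv := range strings.Split(data, ",") {
    		if kv == "" {
    			continue
    		}
    		parts := strings.SplitN(kv, "=", 2)
    		if len(parts) == 2 {
    			m[parts[0]] = parts[1]
    		} else {
    			m[parts[0]] = ""
    		}
    	}
    	return m
    }

    func main() {
    	data := "root_name=root,lower_path=/lower,root_hash=deadbeef"
    	m := parseOpts(data)
    	hash, err := hex.DecodeString(m["root_hash"])
    	fmt.Printf("lower=%s hash=%x err=%v\n", m["lower_path"], hash, err)
    	// The unmodified data string is what MountOptions reports back.
    }
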
@@ -722,6 +803,10 @@ type fileDescription struct {
// underlying file system.
lowerFD *vfs.FileDescription
+ // lowerMappable is the memmap.Mappable corresponding to this file in the
+ // underlying file system.
+ lowerMappable memmap.Mappable
+
// merkleReader is the read-only FileDescription corresponding to the
// Merkle tree file in the underlying file system.
merkleReader *vfs.FileDescription
@@ -755,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) {
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- // TODO(b/162788573): Add integrity check for metadata.
stat, err := fd.lowerFD.Stat(ctx, opts)
if err != nil {
return linux.Statx{}, err
@@ -794,7 +878,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
// Verify that the child is expected.
if dirent.Name != "." && dirent.Name != ".." {
if _, ok := fd.d.childrenNames[dirent.Name]; !ok {
- return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
+ return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
}
}
}
@@ -808,7 +892,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
// The result should contain all children plus "." and "..".
if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
- return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
+ return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
}
for fd.off < int64(len(ds)) {
@@ -875,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui
}
params := &merkletree.GenerateParams{
- TreeReader: &merkleReader,
- TreeWriter: &merkleWriter,
- Children: fd.d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ TreeReader: &merkleReader,
+ TreeWriter: &merkleWriter,
+ Children: fd.d.childrenNames,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
Name: fd.d.name,
Mode: uint32(stat.Mode),
@@ -980,7 +1063,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
// or directory other than the root, the parent Merkle tree file should
// have also been initialized.
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
- return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
+ return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
hash, dataSize, err := fd.generateMerkleLocked(ctx)
@@ -1053,7 +1136,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hosta
if fd.d.fs.allowRuntimeEnable {
return 0, syserror.ENODATA
}
- return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found")
+ return 0, fd.d.fs.alertIntegrityViolation("Ioctl measureVerity: no hash found")
}
// The first part of VerityDigest is the metadata.
@@ -1107,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
case linux.FS_IOC_GETFLAGS:
return fd.verityFlags(ctx, args[2].Pointer())
default:
- // TODO(b/169682228): Investigate which ioctl commands should
- // be allowed.
return 0, syserror.ENOSYS
}
}
@@ -1143,7 +1224,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
return 0, err
@@ -1153,7 +1234,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
dataReader := FileReadWriteSeeker{
@@ -1168,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
fd.d.hashMu.RLock()
n, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: dst.Writer(ctx),
- File: &dataReader,
- Tree: &merkleReader,
- Size: int64(size),
- Name: fd.d.name,
- Mode: fd.d.mode,
- UID: fd.d.uid,
- GID: fd.d.gid,
- Children: fd.d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ Children: fd.d.childrenNames,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
@@ -1186,7 +1266,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
})
fd.d.hashMu.RUnlock()
if err != nil {
- return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
@@ -1201,6 +1281,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op
return 0, syserror.EROFS
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+ fd.lowerMappable = opts.Mappable
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(ctx)
+ opts.MappingIdentity = nil
+ }
+
+ // Check if mmap is allowed on the lower filesystem.
+ if !opts.SentryOwnedContent {
+ return syserror.ENODEV
+ }
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
+}
+
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return fd.lowerFD.LockBSD(ctx, ownerPID, t, block)
@@ -1226,6 +1324,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t
return fd.lowerFD.TestPOSIX(ctx, uid, t, r)
}
+// Translate implements memmap.Mappable.Translate.
+func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
+ ts, err := fd.lowerMappable.Translate(ctx, required, optional, at)
+ if err != nil {
+ return nil, err
+ }
+
+ // dataSize is the size of the whole file.
+ dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ // The Merkle tree file for the child should have been created and
+ // contains the expected xattrs. If the xattr does not exist, it
+ // indicates unexpected modifications to the file system.
+ if err == syserror.ENODATA {
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ // The dataSize xattr should be an integer. If it's not, it indicates
+ // unexpected modifications to the file system.
+ size, err := strconv.Atoi(dataSize)
+ if err != nil {
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ }
+
+ merkleReader := FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+
+ for _, t := range ts {
+		// Content integrity relies on the sentry owning the backing data.
+		// MapInternal is guaranteed to fetch sentry-owned memory because
+		// verity mmaps are disallowed otherwise.
+ ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read)
+ if err != nil {
+ return nil, err
+ }
+ dataReader := mmapReadSeeker{ims, t.Source.Start}
+ var buf bytes.Buffer
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
+ ReadOffset: int64(t.Source.Start),
+ ReadSize: int64(t.Source.Length()),
+ Expected: fd.d.hash,
+ DataAndTreeInSameFile: false,
+ })
+ if err != nil {
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ }
+ }
+ return ts, err
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable)
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
+ fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable)
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (fd *fileDescription) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
+// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass
+// a safemem.BlockSeq pointing to the mapped region as an io.ReaderAt.
+type mmapReadSeeker struct {
+ safemem.BlockSeq
+ Offset uint64
+}
+
+// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file.
+func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) {
+ bs := r.BlockSeq
+ // Adjust the offset into the mapped file to get the offset into the internally
+ // mapped region.
+ readOffset := off - int64(r.Offset)
+ if readOffset < 0 {
+ return 0, syserror.EINVAL
+ }
+ bs.DropFirst64(uint64(readOffset))
+ view := bs.TakeFirst64(uint64(len(p)))
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p))
+ n, err := safemem.CopySeq(dst, view)
+ return int(n), err
+}
+
// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
type FileReadWriteSeeker struct {
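
The offset translation in mmapReadSeeker.ReadAt above rebases a file offset into the internally mapped region before copying. A simplified, self-contained version of the same arithmetic (regionReader is illustrative and byte-slice backed rather than safemem-backed):

    package main

    import (
    	"fmt"
    	"io"
    )

    // regionReader mimics mmapReadSeeker: the mapped region begins at
    // offset within the file, so file offsets are rebased before
    // indexing into the region.
    type regionReader struct {
    	data   []byte // contents of the internally mapped region
    	offset int64  // file offset at which the region starts
    }

    func (r *regionReader) ReadAt(p []byte, off int64) (int, error) {
    	rel := off - r.offset
    	if rel < 0 {
    		return 0, fmt.Errorf("offset %d precedes the mapped region", off)
    	}
    	if rel >= int64(len(r.data)) {
    		return 0, io.EOF
    	}
    	return copy(p, r.data[rel:]), nil
    }

    func main() {
    	r := &regionReader{data: []byte("hello"), offset: 4096}
    	buf := make([]byte, 3)
    	n, err := r.ReadAt(buf, 4097)
    	fmt.Println(n, err, string(buf[:n])) // 3 <nil> ell
    }
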
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 57bd65202..5c78a0019 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
AllowUserMount: true,
})
+ data := "root_name=" + rootMerkleFilename
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: data,
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
Alg: hashAlg,
AllowRuntimeEnable: true,
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e9eb89378..a1ec6daab 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -141,6 +141,7 @@ go_library(
srcs = [
"abstract_socket_namespace.go",
"aio.go",
+ "cgroup.go",
"context.go",
"fd_table.go",
"fd_table_refs.go",
@@ -178,6 +179,7 @@ go_library(
"task.go",
"task_acct.go",
"task_block.go",
+ "task_cgroup.go",
"task_clone.go",
"task_context.go",
"task_exec.go",
@@ -241,6 +243,7 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/fsimpl/timerfd",
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
new file mode 100644
index 000000000..1f1c63f37
--- /dev/null
+++ b/pkg/sentry/kernel/cgroup.go
@@ -0,0 +1,281 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID.
+const InvalidCgroupHierarchyID uint32 = 0
+
+// CgroupControllerType is the name of a cgroup controller.
+type CgroupControllerType string
+
+// CgroupController is the common interface to cgroup controllers available to
+// the entire sentry. The controllers themselves are defined by cgroupfs.
+//
+// Callers of this interface are often unable to access the synchronization
+// needed to ensure returned values remain valid. Some of the values returned from this
+// interface are thus snapshots in time, and may become stale. This is ok for
+// many callers like procfs.
+type CgroupController interface {
+	// Type returns the type of this cgroup controller (e.g. "memory", "cpu").
+	// Returned value is valid for the lifetime of the controller.
+ Type() CgroupControllerType
+
+	// HierarchyID returns the ID of the hierarchy this cgroup controller is
+	// attached to. Returned value is valid for the lifetime of the controller.
+ HierarchyID() uint32
+
+ // Filesystem returns the filesystem this controller is attached to.
+ // Returned value is valid for the lifetime of the controller.
+ Filesystem() *vfs.Filesystem
+
+ // RootCgroup returns the root cgroup for this controller. Returned value is
+ // valid for the lifetime of the controller.
+ RootCgroup() Cgroup
+
+ // NumCgroups returns the number of cgroups managed by this controller.
+ // Returned value is a snapshot in time.
+ NumCgroups() uint64
+
+ // Enabled returns whether this controller is enabled. Returned value is a
+ // snapshot in time.
+ Enabled() bool
+}
+
+// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters
+// a cgroup, it holds a reference on the underlying dentry pointing to the
+// cgroup.
+//
+// +stateify savable
+type Cgroup struct {
+ *kernfs.Dentry
+ CgroupImpl
+}
+
+func (c *Cgroup) decRef() {
+ c.Dentry.DecRef(context.Background())
+}
+
+// Path returns the absolute path of c, relative to its hierarchy root.
+func (c *Cgroup) Path() string {
+ return c.FSLocalPath()
+}
+
+// HierarchyID returns the ID of the hierarchy that contains this cgroup.
+func (c *Cgroup) HierarchyID() uint32 {
+ // Note: a cgroup is guaranteed to have at least one controller.
+ return c.Controllers()[0].HierarchyID()
+}
+
+// CgroupImpl is the common interface to cgroups.
+type CgroupImpl interface {
+ Controllers() []CgroupController
+ Enter(t *Task)
+ Leave(t *Task)
+}
+
+// hierarchy represents a cgroupfs filesystem instance, with a unique set of
+// controllers attached to it. Multiple cgroupfs mounts may reference the same
+// hierarchy.
+//
+// +stateify savable
+type hierarchy struct {
+ id uint32
+ // These are a subset of the controllers in CgroupRegistry.controllers,
+	// grouped here by hierarchy for convenient lookup.
+ controllers map[CgroupControllerType]CgroupController
+ // fs is not owned by hierarchy. The FS is responsible for unregistering the
+ // hierarchy on destruction, which removes this association.
+ fs *vfs.Filesystem
+}
+
+func (h *hierarchy) match(ctypes []CgroupControllerType) bool {
+ if len(ctypes) != len(h.controllers) {
+ return false
+ }
+ for _, ty := range ctypes {
+ if _, ok := h.controllers[ty]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+// CgroupRegistry tracks the active set of cgroup controllers on the system.
+//
+// +stateify savable
+type CgroupRegistry struct {
+	// lastHierarchyID is the ID of the last allocated cgroup hierarchy. Valid
+	// IDs are from 1 to math.MaxUint32. Must be accessed through atomic ops.
+ lastHierarchyID uint32
+
+ mu sync.Mutex `state:"nosave"`
+
+ // controllers is the set of currently known cgroup controllers on the
+ // system. Protected by mu.
+ //
+ // +checklocks:mu
+ controllers map[CgroupControllerType]CgroupController
+
+ // hierarchies is the active set of cgroup hierarchies. Protected by mu.
+ //
+ // +checklocks:mu
+ hierarchies map[uint32]hierarchy
+}
+
+func newCgroupRegistry() *CgroupRegistry {
+ return &CgroupRegistry{
+ controllers: make(map[CgroupControllerType]CgroupController),
+ hierarchies: make(map[uint32]hierarchy),
+ }
+}
+
+// nextHierarchyID returns a newly allocated, unique hierarchy ID.
+func (r *CgroupRegistry) nextHierarchyID() (uint32, error) {
+ if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 {
+ return hid, nil
+ }
+ return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow")
+}
+
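
The hid != 0 test above works because atomic.AddUint32 wraps to zero once the 32-bit ID space is exhausted. A self-contained demonstration:

    package main

    import (
    	"fmt"
    	"math"
    	"sync/atomic"
    )

    func main() {
    	// Simulate an exhausted ID space: the next AddUint32 wraps to 0,
    	// which nextHierarchyID treats as an overflow error.
    	last := uint32(math.MaxUint32)
    	hid := atomic.AddUint32(&last, 1)
    	fmt.Println(hid == 0) // true
    }
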
+// FindHierarchy returns a cgroup filesystem containing exactly the set of
+// controllers named in ctypes. If no such FS is found, FindHierarchy returns
+// nil. FindHierarchy takes a reference on the returned FS, which is transferred
+// to the caller.
+func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ for _, h := range r.hierarchies {
+ if h.match(ctypes) {
+ h.fs.IncRef()
+ return h.fs
+ }
+ }
+
+ return nil
+}
+
+// Register registers the provided set of controllers with the registry as a new
+// hierarchy. If any controller is already registered, the function returns an
+// error without modifying the registry. The hierarchy can later be referenced
+// by the returned ID.
+func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(cs) == 0 {
+ return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers")
+ }
+
+ for _, c := range cs {
+ if _, ok := r.controllers[c.Type()]; ok {
+ return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy")
+ }
+ }
+
+ hid, err := r.nextHierarchyID()
+ if err != nil {
+ return hid, err
+ }
+
+ h := hierarchy{
+ id: hid,
+ controllers: make(map[CgroupControllerType]CgroupController),
+ fs: cs[0].Filesystem(),
+ }
+ for _, c := range cs {
+ n := c.Type()
+ r.controllers[n] = c
+ h.controllers[n] = c
+ }
+ r.hierarchies[hid] = h
+ return hid, nil
+}
+
+// Unregister removes a previously registered hierarchy from the registry. If
+// no such hierarchy was registered, Unregister is a no-op.
+func (r *CgroupRegistry) Unregister(hid uint32) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if h, ok := r.hierarchies[hid]; ok {
+		for name := range h.controllers {
+ delete(r.controllers, name)
+ }
+ delete(r.hierarchies, hid)
+ }
+}
+
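
Register and FindHierarchy together enforce that a controller type is mounted on at most one hierarchy at a time. A self-contained sketch of that invariant, with illustrative names only:

    package main

    import "fmt"

    type registry struct {
    	nextID      uint32
    	controllers map[string]uint32 // controller type -> hierarchy ID
    }

    func (r *registry) register(types []string) (uint32, error) {
    	for _, t := range types {
    		if _, ok := r.controllers[t]; ok {
    			return 0, fmt.Errorf("controller %q already mounted", t)
    		}
    	}
    	r.nextID++
    	for _, t := range types {
    		r.controllers[t] = r.nextID
    	}
    	return r.nextID, nil
    }

    func main() {
    	r := &registry{controllers: make(map[string]uint32)}
    	fmt.Println(r.register([]string{"cpu", "memory"})) // 1 <nil>
    	fmt.Println(r.register([]string{"cpu"}))           // 0 error
    }
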
+// computeInitialGroups takes a reference on each of the returned cgroups. The
+// caller takes ownership of these references.
+func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ ctlSet := make(map[CgroupControllerType]CgroupController)
+ cgset := make(map[Cgroup]struct{})
+
+ // Remember controllers from the inherited cgroups set...
+	for cg := range inherit {
+ cg.IncRef() // Ref transferred to caller.
+ for _, ctl := range cg.Controllers() {
+ ctlSet[ctl.Type()] = ctl
+ cgset[cg] = struct{}{}
+ }
+ }
+
+ // ... and add the root cgroups of all the missing controllers.
+ for name, ctl := range r.controllers {
+ if _, ok := ctlSet[name]; !ok {
+ cg := ctl.RootCgroup()
+ cg.IncRef() // Ref transferred to caller.
+ cgset[cg] = struct{}{}
+ }
+ }
+ return cgset
+}
+
+// GenerateProcCgroups writes the contents of /proc/cgroups to buf.
+func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) {
+ r.mu.Lock()
+ entries := make([]string, 0, len(r.controllers))
+ for _, c := range r.controllers {
+ en := 0
+ if c.Enabled() {
+ en = 1
+ }
+ entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en))
+ }
+ r.mu.Unlock()
+
+ sort.Strings(entries)
+ fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n")
+ for _, e := range entries {
+ fmt.Fprint(buf, e)
+ }
+}
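
For illustration only: with cpu and cpuacct controllers on hierarchy 1 and memory on hierarchy 2, each managing a single cgroup, GenerateProcCgroups would emit tab-separated lines like:

    #subsys_name	hierarchy	num_cgroups	enabled
    cpu	1	1	1
    cpuacct	1	1	1
    memory	2	1	1
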
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 43065b45a..e6e9da898 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -294,6 +294,11 @@ type Kernel struct {
// YAMAPtraceScope is the current level of YAMA ptrace restrictions.
YAMAPtraceScope int32
+
+ // cgroupRegistry contains the set of active cgroup controllers on the
+	// system. It is controlled by cgroupfs. Nil if cgroupfs is unavailable on
+ // the system.
+ cgroupRegistry *CgroupRegistry
}
// InitKernelArgs holds arguments to Init.
@@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.socketMount = socketMount
k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
+
+ k.cgroupRegistry = newCgroupRegistry()
}
return nil
}
@@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+// CgroupRegistry returns the cgroup registry.
+func (k *Kernel) CgroupRegistry() *CgroupRegistry {
+ return k.cgroupRegistry
+}
+
// Release releases resources owned by k.
//
// Precondition: This should only be called after the kernel is fully
@@ -1831,3 +1843,43 @@ func (k *Kernel) Release() {
k.timekeeper.Destroy()
k.vdso.Release(ctx)
}
+
+// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup
+// hierarchy.
+//
+// Precondition: root must be a new cgroup with no tasks. This implies the
+// controllers for root are also new and currently manage no task, which in turn
+// implies the new cgroup can be populated without migrating tasks between
+// cgroups.
+func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.exitState != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+ t.enterCgroupLocked(root)
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
+
+// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the
+// hierarchy with the provided ID. This is intended for use during hierarchy
+// teardown, as otherwise the tasks would be orphaned w.r.t. some controllers.
+func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.exitState != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+		for cg := range t.cgroups {
+ if cg.HierarchyID() == hid {
+ t.leaveCgroupLocked(cg)
+ }
+ }
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 399985039..be1371855 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -587,6 +587,12 @@ type Task struct {
//
// kcov is exclusive to the task goroutine.
kcov *Kcov
+
+ // cgroups is the set of cgroups this task belongs to. This may be empty if
+ // no cgroup controllers are enabled. Protected by mu.
+ //
+ // +checklocks:mu
+ cgroups map[Cgroup]struct{}
}
func (t *Task) savePtraceTracer() *Task {
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
new file mode 100644
index 000000000..25d2504fa
--- /dev/null
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -0,0 +1,138 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// EnterInitialCgroups moves t into an initial set of cgroups.
+//
+// Precondition: t isn't in any cgroups yet; t.cgroups is empty.
+//
+// +checklocksignore parent.mu is conditionally acquired.
+func (t *Task) EnterInitialCgroups(parent *Task) {
+ var inherit map[Cgroup]struct{}
+ if parent != nil {
+ parent.mu.Lock()
+ defer parent.mu.Unlock()
+ inherit = parent.cgroups
+ }
+ joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit)
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // Transfer ownership of joinSet refs to the task's cgset.
+ t.cgroups = joinSet
+	for c := range t.cgroups {
+ // Since t isn't in any cgroup yet, we can skip the check against
+ // existing cgroups.
+ c.Enter(t)
+ }
+}
+
+// EnterCgroup moves t into c.
+func (t *Task) EnterCgroup(c Cgroup) error {
+ newControllers := make(map[CgroupControllerType]struct{})
+ for _, ctl := range c.Controllers() {
+ newControllers[ctl.Type()] = struct{}{}
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+	for oldCG := range t.cgroups {
+ for _, oldCtl := range oldCG.Controllers() {
+ if _, ok := newControllers[oldCtl.Type()]; ok {
+ // Already in a cgroup with the same controller as one of the
+ // new ones. Requires migration between cgroups.
+ //
+ // TODO(b/183137098): Implement cgroup migration.
+ log.Warningf("Cgroup migration is not implemented")
+ return syserror.EBUSY
+ }
+ }
+ }
+
+ // No migration required.
+ t.enterCgroupLocked(c)
+
+ return nil
+}
+
+// +checklocks:t.mu
+func (t *Task) enterCgroupLocked(c Cgroup) {
+ c.IncRef()
+ t.cgroups[c] = struct{}{}
+ c.Enter(t)
+}
+
+// LeaveCgroups removes t from all its cgroups.
+func (t *Task) LeaveCgroups() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+	for c := range t.cgroups {
+ t.leaveCgroupLocked(c)
+ }
+}
+
+// +checklocks:t.mu
+func (t *Task) leaveCgroupLocked(c Cgroup) {
+ c.Leave(t)
+ delete(t.cgroups, c)
+ c.decRef()
+}
+
+// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
+// format a cgroup for display.
+type taskCgroupEntry struct {
+ hierarchyID uint32
+ controllers string
+ path string
+}
+
+// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf.
+func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups))
+	for c := range t.cgroups {
+ ctls := c.Controllers()
+ ctlNames := make([]string, 0, len(ctls))
+ for _, ctl := range ctls {
+ ctlNames = append(ctlNames, string(ctl.Type()))
+ }
+
+ cgEntries = append(cgEntries, taskCgroupEntry{
+ // Note: We're guaranteed to have at least one controller, and all
+ // controllers are guaranteed to be on the same hierarchy.
+ hierarchyID: ctls[0].HierarchyID(),
+ controllers: strings.Join(ctlNames, ","),
+ path: c.Path(),
+ })
+ }
+
+ sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID })
+ for _, cgE := range cgEntries {
+ fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path)
+ }
+}
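
For illustration only: a task in the root cgroup of a cpu,cpuacct hierarchy (ID 2) and of a memory hierarchy (ID 1) would produce entries sorted by descending hierarchy ID:

    2:cpu,cpuacct:/
    1:memory:/
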
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index ad59e4f60..b1af1a7ef 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.fsContext.DecRef(t)
t.fdTable.DecRef(t)
+ // Detach task from all cgroups. This must happen before potentially the
+ // last ref to the cgroupfs mount is dropped below.
+ t.LeaveCgroups()
+
t.mu.Lock()
if t.mountNamespaceVFS2 != nil {
t.mountNamespaceVFS2.DecRef(t)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index fc18b6253..32031cd70 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
+ cgroups: make(map[Cgroup]struct{}),
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
@@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
t.parent.children[t] = struct{}{}
}
+ if VFS2Enabled {
+ t.EnterInitialCgroups(t.parent)
+ }
+
if tg.leader == nil {
// New thread group.
tg.leader = t
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 2c658d001..601fc0d3a 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -30,8 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application")
-
// SyscallRestartBlock represents the restart block for a syscall restartable
// with a custom function. It encapsulates the state required to restart a
// syscall across a S/R.
@@ -284,7 +282,7 @@ func (*runSyscallExit) execute(t *Task) taskRunState {
// indicated by an execution fault at address addr. doVsyscall returns the
// task's next run state.
func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState {
- vsyscallCount.Increment()
+ metric.WeirdnessMetric.Increment("vsyscall_count")
// Grab the caller up front, to make sure there's a sensible stack.
caller := t.Arch().Native(uintptr(0))
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 09d070ec8..77ad62445 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
}
}
+// forEachTaskLocked applies f to each Task in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) {
+ for t := range ts.Root.tids {
+ f(t)
+ }
+}
+
// A PIDNamespace represents a PID namespace, a bimap between thread IDs and
// tasks. See the pid_namespaces(7) man page for further details.
//
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index ecb6603a1..4c65215fa 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -11,11 +11,12 @@ go_library(
"vdso.go",
"vdso_state.go",
],
+ marshal = True,
+ marshal_debug = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi",
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/cpuid",
"//pkg/hostarch",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index e92d9fdc3..8fc3e2a79 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -47,10 +46,10 @@ const (
var (
// header64Size is the size of elf.Header64.
- header64Size = int(binary.Size(elf.Header64{}))
+ header64Size = (*linux.ElfHeader64)(nil).SizeBytes()
// Prog64Size is the size of elf.Prog64.
- prog64Size = int(binary.Size(elf.Prog64{}))
+ prog64Size = (*linux.ElfProg64)(nil).SizeBytes()
)
func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType {
@@ -136,7 +135,6 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
log.Infof("Unsupported ELF endianness: %v", endian)
return elfInfo{}, syserror.ENOEXEC
}
- byteOrder := binary.LittleEndian
if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT {
log.Infof("Unsupported ELF version: %v", version)
@@ -145,7 +143,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// EI_OSABI is ignored by Linux, which is the only OS supported.
os := abi.Linux
- var hdr elf.Header64
+ var hdr linux.ElfHeader64
hdrBuf := make([]byte, header64Size)
_, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
if err != nil {
@@ -156,7 +154,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
}
return elfInfo{}, err
}
- binary.Unmarshal(hdrBuf, byteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBuf)
// We support amd64 and arm64.
var a arch.Arch
@@ -213,8 +211,8 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
phdrs := make([]elf.ProgHeader, hdr.Phnum)
for i := range phdrs {
- var prog64 elf.Prog64
- binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64)
+ var prog64 linux.ElfProg64
+ prog64.UnmarshalUnsafe(phdrBuf[:prog64Size])
phdrBuf = phdrBuf[prog64Size:]
phdrs[i] = elf.ProgHeader{
Type: elf.ProgType(prog64.Type),
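
The UnmarshalUnsafe calls above come from go_marshal-generated code that replaces pkg/binary. A hedged stand-in using only the standard library shows the same fixed-layout, little-endian decode of an ELF64 header buffer:

    package main

    import (
    	"bytes"
    	"debug/elf"
    	"encoding/binary"
    	"fmt"
    )

    func main() {
    	// A zeroed buffer of exactly one ELF64 header, as f.ReadFull
    	// would produce for an all-zero file in the hunk above.
    	hdrBuf := make([]byte, binary.Size(elf.Header64{}))
    	var hdr elf.Header64
    	if err := binary.Read(bytes.NewReader(hdrBuf), binary.LittleEndian, &hdr); err != nil {
    		fmt.Println(err)
    		return
    	}
    	fmt.Println(elf.Type(hdr.Type)) // ET_NONE for the zeroed buffer
    }
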
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 72868646a..610686ea0 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -375,6 +375,11 @@ type MMapOpts struct {
//
// If Force is true, Unmap and Fixed must be true.
Force bool
+
+ // SentryOwnedContent indicates the sentry exclusively controls the
+	// underlying memory backing the mapping, so the memory content is
+ // guaranteed not to be modified outside the sentry's purview.
+ SentryOwnedContent bool
}
// File represents a host file that may be mapped into a platform.AddressSpace.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index f04898dc1..b307832fd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -65,6 +65,7 @@ go_test(
name = "kvm_test",
srcs = [
"kvm_amd64_test.go",
+ "kvm_amd64_test.s",
"kvm_arm64_test.go",
"kvm_test.go",
"virtual_map_test.go",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index fd1131638..bb9967b9f 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -16,7 +16,6 @@ package kvm
import (
"fmt"
- "reflect"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/ring0"
@@ -36,6 +35,14 @@ func sighandler()
// dieArchSetup and the assembly implementation for dieTrampoline.
func dieTrampoline()
+// Return the start addresses of the functions above.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfSighandler() uintptr
+func addrOfDieTrampoline() uintptr
+
var (
// bounceSignal is the signal used for bouncing KVM.
//
@@ -87,10 +94,10 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
// Extract the address for the trampoline.
- dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer()
+ dieTrampolineAddr = addrOfDieTrampoline()
}
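The addrOf pattern above is worth seeing in isolation. With Go 1.17's register-based calling convention, reflect.ValueOf(fn).Pointer() on an assembly function resolves to an ABIInternal wrapper rather than the ABI0 entry point, so the address must be taken from assembly. A hedged two-file sketch with a hypothetical handler routine (not gVisor's symbols):

// file: addr.go
package main

import "fmt"

func handler()               // Defined in addr_amd64.s.
func addrOfHandler() uintptr // Also in assembly; returns the ABI0 address of handler.

func main() {
	fmt.Printf("handler is at %#x\n", addrOfHandler())
}

// file: addr_amd64.s (shown here as comments to keep this block Go):
//
//	#include "textflag.h"
//
//	TEXT ·handler(SB), NOSPLIT, $0
//		RET
//
//	// func addrOfHandler() uintptr
//	TEXT ·addrOfHandler(SB), $0-8
//		MOVQ $·handler(SB), AX
//		MOVQ AX, ret+0(FP)
//		RET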
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 025ea93b5..953024600 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -81,8 +81,20 @@ fallback:
MOVQ ·savedHandler(SB), AX
JMP AX
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+ MOVQ $·sighandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
PUSHQ BX // First argument (vCPU).
PUSHQ AX // Fake the old RIP as caller.
JMP ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+ MOVQ $·dieTrampoline(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 09c7e88e5..308f2a951 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -92,6 +92,12 @@ fallback:
MOVD ·savedHandler(SB), R7
B (R7)
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+ MOVD $·sighandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
// R0: Fake the old PC as caller
@@ -99,3 +105,9 @@ TEXT ·dieTrampoline(SB),NOSPLIT,$0
MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
B ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+ MOVD $·dieTrampoline(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index e44e995a0..b8dd1e4a5 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -49,3 +49,40 @@ func TestSegments(t *testing.T) {
return false
})
}
+
+// stmxcsr reads the MXCSR control and status register.
+func stmxcsr(addr *uint32)
+
+func TestMXCSR(t *testing.T) {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ switchOpts := ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: &dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }
+
+ const mxcsrControlMask = uint32(0x1f80)
+ mxcsrBefore := uint32(0)
+ mxcsrAfter := uint32(0)
+ stmxcsr(&mxcsrBefore)
+ if mxcsrBefore == 0 {
+ // The Go runtime sets MXCSR to 0x1f80 and never changes
+ // its control bits.
+ panic("mxcsr is zero")
+ }
+ switchOpts.FloatingPointState.SetMXCSR(0)
+ if _, err := c.SwitchToUser(
+ switchOpts, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != nil {
+ t.Errorf("application syscall failed: %v", err)
+ }
+ stmxcsr(&mxcsrAfter)
+ if mxcsrAfter&mxcsrControlMask != mxcsrBefore&mxcsrControlMask {
+ t.Errorf("mxcsr = %x (expected %x)", mxcsrAfter, mxcsrBefore)
+ }
+ return false
+ })
+}
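For context: 0x1f80 is MXCSR with all six SIMD exception mask bits (bits 7-12) set and round-to-nearest selected, which is what the Go runtime installs at startup. The test compares only those control bits, since the status flags (bits 0-5) may legitimately change. A tiny sketch of the comparison, assuming the constant above:

package main

import "fmt"

const mxcsrControlMask = uint32(0x1f80) // Exception mask bits 7-12.

func controlBitsPreserved(before, after uint32) bool {
	return before&mxcsrControlMask == after&mxcsrControlMask
}

func main() {
	// 0x1fa1 differs from 0x1f80 only in status bits, so this is true.
	fmt.Println(controlBitsPreserved(0x1f80, 0x1fa1))
}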
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.s b/pkg/sentry/platform/kvm/kvm_amd64_test.s
new file mode 100644
index 000000000..8e9079867
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.s
@@ -0,0 +1,21 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// stmxcsr reads the MXCSR control and status register.
+TEXT ·stmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ STMXCSR (SI)
+ RET
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 2492d57be..eb2dcccac 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -66,6 +66,7 @@ const (
_KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
_KVM_CAP_VCPU_EVENTS = 0x29
_KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e
+ _KVM_CAP_TSC_CONTROL = 0x3c
)
// KVM limits.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index b3d4188a3..1b5d5f66e 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -67,11 +67,17 @@ type machine struct {
// maxSlots is the maximum number of memory slots supported by the machine.
maxSlots int
+ // tscControl indicates whether the CPU supports TSC scaling.
+ tscControl bool
+
// usedSlots is the set of used physical addresses (sorted).
usedSlots []uintptr
// nextID is the next vCPU ID.
nextID uint32
+
+ // machineArchState is the architecture-specific state.
+ machineArchState
}
const (
@@ -193,12 +199,7 @@ func newMachine(vm int) (*machine, error) {
m.available.L = &m.mu
// Pull the maximum vCPUs.
- maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
- if errno != 0 {
- m.maxVCPUs = _KVM_NR_VCPUS
- } else {
- m.maxVCPUs = int(maxVCPUs)
- }
+ m.getMaxVCPU()
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
@@ -214,6 +215,11 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of slots is %d.", m.maxSlots)
m.usedSlots = make([]uintptr, m.maxSlots)
+ // Check for TSC scaling support.
+ hasTSCControl, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_TSC_CONTROL)
+ m.tscControl = errno == 0 && hasTSCControl == 1
+ log.Debugf("TSC scaling support: %t.", m.tscControl)
+
// Create the upper shared pagetables and kernel(sentry) pagetables.
m.upperSharedPageTables = pagetables.New(newAllocator())
m.mapUpperHalf(m.upperSharedPageTables)
@@ -419,9 +425,8 @@ func (m *machine) Get() *vCPU {
}
}
- // Create a new vCPU (maybe).
- if int(m.nextID) < m.maxVCPUs {
- c := m.newVCPU()
+ // Get a new vCPU (maybe).
+ if c := m.getNewVCPU(); c != nil {
c.lock()
m.vCPUsByTID[tid] = c
m.mu.Unlock()
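The TSC probe follows the standard KVM_CHECK_EXTENSION convention: errno is set only if the ioctl itself fails, and a positive return value indicates the capability is present. A standalone sketch of the probe; the ioctl number matches _IO(KVMIO, 0x03) and the fd is a placeholder:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const (
	_KVM_CHECK_EXTENSION = 0xae03 // _IO(KVMIO, 0x03).
	_KVM_CAP_TSC_CONTROL = 0x3c
)

// hasTSCControl mirrors the check in newMachine above.
func hasTSCControl(vmFd uintptr) bool {
	r, _, errno := unix.RawSyscall(unix.SYS_IOCTL, vmFd, _KVM_CHECK_EXTENSION, _KVM_CAP_TSC_CONTROL)
	return errno == 0 && r == 1
}

func main() {
	fmt.Println(hasTSCControl(3)) // fd 3 stands in for a real VM fd.
}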
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index e8e209249..9a2337654 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -63,6 +63,9 @@ func (m *machine) initArchState() error {
return nil
}
+type machineArchState struct {
+}
+
type vCPUArchState struct {
// PCIDs is the set of PCIDs for this vCPU.
//
@@ -213,6 +216,11 @@ func (c *vCPU) setSystemTime() error {
// capabilities as it is emulated in KVM. We don't actually use this
// capability, but it means that this method should be robust to
// different hardware configurations.
+
+ // If TSC scaling is not supported, fall back to the legacy mode.
+ if !c.machine.tscControl {
+ return c.setSystemTimeLegacy()
+ }
rawFreq, err := c.getTSCFreq()
if err != nil {
return c.setSystemTimeLegacy()
@@ -346,6 +354,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
// allocations occur.
entersyscall()
bluepill(c)
+ // The root table physical page must be mapped so that iret or sysret
+ // does not fault after switching into a user address space. sysret and
+ // iret themselves live in the upper half, which is global and already mapped.
+ switchOpts.PageTables.PrefaultRootTable()
prefaultFloatingPointState(switchOpts.FloatingPointState)
vector = c.CPU.SwitchToUser(switchOpts)
exitsyscall()
@@ -490,3 +502,22 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
physical)
}
}
+
+// getMaxVCPU determines the maximum number of vCPUs.
+func (m *machine) getMaxVCPU() {
+ maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
+ if errno != 0 {
+ m.maxVCPUs = _KVM_NR_VCPUS
+ } else {
+ m.maxVCPUs = int(maxVCPUs)
+ }
+}
+
+// getNewVCPU creates a new vCPU if the vCPU limit has not been reached.
+func (m *machine) getNewVCPU() *vCPU {
+ if int(m.nextID) < m.maxVCPUs {
+ c := m.newVCPU()
+ return c
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 03e84d804..8926b1d9f 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -17,6 +17,10 @@
package kvm
import (
+ "runtime"
+ "sync/atomic"
+
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
@@ -25,6 +29,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
)
+type machineArchState struct {
+ // initialvCPUs is the set of vCPUs that have been initialized but are not yet in use.
+ initialvCPUs map[int]*vCPU
+}
+
type vCPUArchState struct {
// PCIDs is the set of PCIDs for this vCPU.
//
@@ -47,7 +56,7 @@ const (
// Beyond a relatively small number, there are likely few performance
// benefits, since the TLB has likely long since lost any translations
// from more than a few PCIDs past.
- poolPCIDs = 8
+ poolPCIDs = 128
)
func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
@@ -182,3 +191,30 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType,
return accessType, platform.ErrContextSignal
}
+
+// getMaxVCPU determines the maximum number of vCPUs.
+func (m *machine) getMaxVCPU() {
+ rmaxVCPUs := runtime.NumCPU()
+ smaxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
+ // Compare the maximum vCPU counts from the runtime and the syscall, and use the smaller one.
+ if errno != 0 {
+ m.maxVCPUs = rmaxVCPUs
+ } else {
+ if rmaxVCPUs < int(smaxVCPUs) {
+ m.maxVCPUs = rmaxVCPUs
+ } else {
+ m.maxVCPUs = int(smaxVCPUs)
+ }
+ }
+}
+
+// getNewVCPU scans initialvCPUs for an available vCPU.
+func (m *machine) getNewVCPU() *vCPU {
+ for CID, c := range m.initialvCPUs {
+ if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
+ delete(m.initialvCPUs, CID)
+ return c
+ }
+ }
+ return nil
+}
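The claim-by-CAS in getNewVCPU above is what keeps the pre-initialized pool safe: a vCPU leaves initialvCPUs only if its state moves atomically from ready to in-use, so two threads can never claim the same vCPU. A minimal sketch with stand-in types and states; in the real code the map itself is additionally guarded by the machine mutex:

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	ready uint32 = iota // Stand-ins for vCPUReady and vCPUUser.
	user
)

type vCPU struct{ state uint32 }

// claim removes and returns a pool vCPU whose state can be moved
// ready -> user atomically; the CAS guards the vCPU, while the map
// is assumed to be protected by a caller-held lock.
func claim(pool map[int]*vCPU) *vCPU {
	for id, c := range pool {
		if atomic.CompareAndSwapUint32(&c.state, ready, user) {
			delete(pool, id)
			return c
		}
	}
	return nil
}

func main() {
	pool := map[int]*vCPU{1: {}, 2: {}}
	fmt.Println(claim(pool) != nil, len(pool)) // true 1
}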
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 634e55ec0..92edc992b 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
type kvmVcpuInit struct {
@@ -47,6 +48,19 @@ func (m *machine) initArchState() error {
uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno))
}
+
+ // Initialize all vCPUs up front on ARM64; this does not happen on x86_64.
+ // The two architectures have different KVM timer mechanisms: if a vCPU is
+ // created dynamically on ARM64, its timer is skewed for a short period.
+ // For more detail, see https://github.com/google/gvisor/issues/5739.
+ m.initialvCPUs = make(map[int]*vCPU)
+ m.mu.Lock()
+ for int(m.nextID) < m.maxVCPUs-1 {
+ c := m.newVCPU()
+ c.state = 0
+ m.initialvCPUs[c.id] = c
+ }
+ m.mu.Unlock()
return nil
}
@@ -174,9 +188,58 @@ func (c *vCPU) setTSC(value uint64) error {
return nil
}
+// getTSC reads the guest timer counter, i.e. the physical counter minus the virtual offset.
+func (c *vCPU) getTSC() error {
+ var (
+ reg kvmOneReg
+ data uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ reg.id = _KVM_ARM64_REGS_TIMER_CNT
+
+ if err := c.getOneRegister(&reg); err != nil {
+ return err
+ }
+
+ return nil
+}
+
// setSystemTime sets the vCPU to the system time.
func (c *vCPU) setSystemTime() error {
- return c.setSystemTimeLegacy()
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Use getTSC to estimate where the counter will be on the
+ // host during a "fast" system call iteration.
+ // Could replacing getTSC with another setOneRegister syscall yield a more accurate value?
+ start := uint64(ktime.Rdtsc())
+ if err := c.getTSC(); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~12.5% of the minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && (current <= upperThreshold || minimum < 50) {
+ // Try to set the TSC
+ if err := c.setTSC(end + (minimum / 2)); err != nil {
+ return err
+ }
+ return nil
+ }
+ }
}
//go:nosplit
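Spelled out, the threshold expression in setSystemTime above is ((m << 3) + m) >> 3 = 9m/8, i.e. 12.5% above the minimum observed call time:

package main

import "fmt"

// upperThreshold mirrors the expression in setSystemTime.
func upperThreshold(minimum uint64) uint64 {
	return ((minimum << 3) + minimum) >> 3 // 9*minimum/8.
}

func main() {
	fmt.Println(upperThreshold(400)) // 450 = minimum + 12.5%.
}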
@@ -203,7 +266,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
uintptr(c.fd),
_KVM_GET_ONE_REG,
uintptr(unsafe.Pointer(reg))); errno != 0 {
- return fmt.Errorf("error setting one register: %v", errno)
+ return fmt.Errorf("error getting one register: %v", errno)
}
return nil
}
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 16f9c523e..d5c3f901f 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -109,6 +109,12 @@ parent_dead:
SYSCALL
HLT
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+ MOVQ $·stub(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// stubCall calls the stub function at the given address with the given PPID.
//
// This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 6162df02a..4664cd4ad 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -102,6 +102,12 @@ parent_dead:
SVC
HLT
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+ MOVD $·stub(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// stubCall calls the stub function at the given address with the given PPID.
//
// This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
index 5c9b7784f..1fbdea898 100644
--- a/pkg/sentry/platform/ptrace/stub_unsafe.go
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -26,6 +26,13 @@ import (
// stub is defined in arch-specific assembly.
func stub()
+// addrOfStub returns the start address of stub.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfStub() uintptr
+
// stubCall calls the stub at the given address with the given pid.
func stubCall(addr, pid uintptr)
@@ -41,7 +48,7 @@ func unsafeSlice(addr uintptr, length int) (slice []byte) {
// stubInit initializes the stub.
func stubInit() {
// Grab the existing stub.
- stubBegin := reflect.ValueOf(stub).Pointer()
+ stubBegin := addrOfStub()
stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin)
stubSlice := unsafeSlice(stubBegin, stubLen)
mapLen := uintptr(stubLen)
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 080859125..7ee89a735 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 0e0e82365..2029e7cf4 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -14,9 +14,11 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/context",
"//pkg/hostarch",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 45a05cd63..235b9c306 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -18,9 +18,11 @@ package control
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := binary.AlignDown(cap(buf)-len(buf), 4)
+ space := bits.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
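bits.AlignUp and bits.AlignDown are drop-in replacements for the old binary helpers. Assuming the usual power-of-two semantics, the arithmetic is a pair of mask operations; a self-contained sketch:

package main

import "fmt"

// alignDown rounds length down to a multiple of align; align must be a
// power of two.
func alignDown(length, align int) int {
	return length &^ (align - 1)
}

// alignUp rounds length up to a multiple of align.
func alignUp(length, align int) int {
	return alignDown(length+align-1, align)
}

func main() {
	fmt.Println(alignUp(13, 4), alignDown(13, 4)) // 16 12
}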
@@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
return alignSlice(buf, align), flags
}
-func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte {
+func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte {
if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
return buf
}
@@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = putUint32(buf, msgType)
hdrBuf := buf
-
- buf = binary.Marshal(buf, hostarch.ByteOrder, data)
+ buf = append(buf, marshal.Marshal(data)...)
// If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
@@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := binary.AlignUp(len(buf), align)
+ aligned := bits.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte {
// PackTimestamp packs a SO_TIMESTAMP socket control message.
func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+ timestampP := linux.NsecToTimeval(timestamp)
return putCmsgStruct(
buf,
linux.SOL_SOCKET,
linux.SO_TIMESTAMP,
t.Arch().Width(),
- linux.NsecToTimeval(timestamp),
+ &timestampP,
)
}
@@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
linux.SOL_TCP,
linux.TCP_INQ,
t.Arch().Width(),
- inq,
+ primitive.AllocateInt32(inq),
)
}
@@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
linux.SOL_IP,
linux.IP_TOS,
t.Arch().Width(),
- tos,
+ primitive.AllocateUint8(tos),
)
}
@@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
linux.SOL_IPV6,
linux.IPV6_TCLASS,
t.Arch().Width(),
- tClass,
+ primitive.AllocateUint32(tClass),
)
}
@@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
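cmsgSpace matches the kernel's CMSG_SPACE: the control message header size plus the data length rounded up to the task's word width. A quick sketch with 64-bit constants assumed (16-byte header, 8-byte width):

package main

import "fmt"

const sizeOfControlMessageHeader = 16 // linux.SizeOfControlMessageHeader on 64-bit.

func cmsgSpace(dataLen, width int) int {
	aligned := (dataLen + width - 1) &^ (width - 1)
	return sizeOfControlMessageHeader + aligned
}

func main() {
	fmt.Println(cmsgSpace(4, 8)) // 24: header + 4 data bytes padded to 8.
}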
@@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
+ h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
return socket.ControlMessages{}, syserror.EINVAL
@@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
+ creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
scmCreds, err := NewSCMCredentials(t, creds)
if err != nil {
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.SO_TIMESTAMP:
if length < linux.SizeOfTimeval {
return socket.ControlMessages{}, syserror.EINVAL
}
var ts linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts)
+ ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
cmsgs.IP.Timestamp = ts.ToNsecCapped()
cmsgs.IP.HasTimestamp = true
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
// Unknown message type.
@@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTOS = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS)
- i += binary.AlignUp(length, width)
+ var tos primitive.Uint8
+ tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS])
+ cmsgs.IP.TOS = uint8(tos)
+ i += bits.AlignUp(length, width)
case linux.IP_PKTINFO:
if length < linux.SizeOfControlMessageIPPacketInfo {
@@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
cmsgs.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
+ packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo])
cmsgs.IP.PacketInfo = packetInfo
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IP_RECVERR:
var errCmsg linux.SockErrCMsgIPv4
@@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
cmsgs.IP.SockErr = &errCmsg
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTClass = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass)
- i += binary.AlignUp(length, width)
+ var tclass primitive.Uint32
+ tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass])
+ cmsgs.IP.TClass = uint32(tclass)
+ i += bits.AlignUp(length, width)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IPV6_RECVERR:
var errCmsg linux.SockErrCMsgIPv6
@@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
cmsgs.IP.SockErr = &errCmsg
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a5c2155a2..2e3064565 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -17,7 +17,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/hostarch",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a784e23b5..52ae4bc9c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -528,24 +527,28 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.SO_TIMESTAMP:
controlMessages.IP.HasTimestamp = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], hostarch.ByteOrder, &controlMessages.IP.Timestamp)
+ ts := linux.Timeval{}
+ ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
+ controlMessages.IP.Timestamp = ts.ToNsecCapped()
}
case linux.SOL_IP:
switch unixCmsg.Header.Type {
case linux.IP_TOS:
controlMessages.IP.HasTOS = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS)
+ var tos primitive.Uint8
+ tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()])
+ controlMessages.IP.TOS = uint8(tos)
case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
+ packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()])
controlMessages.IP.PacketInfo = packetInfo
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
controlMessages.IP.OriginalDstAddress = &addr
case unix.IP_RECVERR:
@@ -558,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass)
+ var tclass primitive.Uint32
+ tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()])
+ controlMessages.IP.TClass = uint32(tclass)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
controlMessages.IP.OriginalDstAddress = &addr
case unix.IPV6_RECVERR:
@@ -575,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.TCP_INQ:
controlMessages.IP.HasInq = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq)
+ var inq primitive.Int32
+ inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq])
+ controlMessages.IP.Inq = int32(inq)
}
}
}
@@ -689,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 {
return 0
}
- binary.Unmarshal(buf, hostarch.ByteOrder, &info)
+ info.UnmarshalUnsafe(buf[:info.SizeBytes()])
return uint32(info.State)
}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 26e8ae17a..393a1ab3a 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -15,6 +15,7 @@
package hostinet
import (
+ "encoding/binary"
"fmt"
"io"
"io/ioutil"
@@ -26,10 +27,10 @@ import (
"syscall"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -147,8 +148,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
if len(link.Data) < unix.SizeofIfInfomsg {
return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg)
}
- var ifinfo unix.IfInfomsg
- binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo)
+ var ifinfo linux.InterfaceInfoMessage
+ ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()])
inetIF := inet.Interface{
DeviceType: ifinfo.Type,
Flags: ifinfo.Flags,
@@ -178,11 +179,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
if len(addr.Data) < unix.SizeofIfAddrmsg {
return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg)
}
- var ifaddr unix.IfAddrmsg
- binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr)
+ var ifaddr linux.InterfaceAddrMessage
+ ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()])
inetAddr := inet.InterfaceAddr{
Family: ifaddr.Family,
- PrefixLen: ifaddr.Prefixlen,
+ PrefixLen: ifaddr.PrefixLen,
Flags: ifaddr.Flags,
}
attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
@@ -210,13 +211,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
continue
}
- var ifRoute unix.RtMsg
- binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute)
+ var ifRoute linux.RouteMessage
+ ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()])
inetRoute := inet.Route{
Family: ifRoute.Family,
- DstLen: ifRoute.Dst_len,
- SrcLen: ifRoute.Src_len,
- TOS: ifRoute.Tos,
+ DstLen: ifRoute.DstLen,
+ SrcLen: ifRoute.SrcLen,
+ TOS: ifRoute.TOS,
Table: ifRoute.Table,
Protocol: ifRoute.Protocol,
Scope: ifRoute.Scope,
@@ -245,7 +246,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
if len(attr.Value) != expected {
return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
}
- binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface)
+ var outputIF primitive.Int32
+ outputIF.UnmarshalUnsafe(attr.Value)
+ inetRoute.OutputInterface = int32(outputIF)
}
}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 4381dfa06..61b2c9755 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -14,14 +14,16 @@ go_library(
"tcp_matcher.go",
"udp_matcher.go",
],
+ marshal = True,
# This target depends on netstack and should only be used by epsocket,
# which is allowed to depend on netstack.
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/hostarch",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/sentry/kernel",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 4bd305a44..6fc7781ad 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
nflog("marshaling matcher %q", name)
// We have to pad this struct size to a multiple of 8 bytes.
- size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
matcher := linux.KernelXTEntryMatch{
XTEntryMatch: linux.XTEntryMatch{
MatchSize: uint16(size),
@@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte {
}
copy(matcher.Name[:], name)
- buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, hostarch.ByteOrder, matcher)
- return append(buf, make([]byte, size-len(buf))...)
+ buf := make([]byte, size)
+ entryLen := matcher.XTEntryMatch.SizeBytes()
+ matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen])
+ copy(buf[entryLen:], matcher.Data)
+ return buf
}
func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 1fc4cb651..cb78ef60b 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -18,8 +18,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
var entry linux.IPTEntry
- buf := optVal[:linux.SizeOfIPTEntry]
- binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
+ entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[linux.SizeOfIPTEntry:]
+ optVal = optVal[entry.SizeBytes():]
if entry.TargetOffset < linux.SizeOfIPTEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 67a52b628..5cb7fe4aa 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -18,8 +18,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
var entry linux.IP6TEntry
- buf := optVal[:linux.SizeOfIP6TEntry]
- binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
+ entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[linux.SizeOfIP6TEntry:]
+ optVal = optVal[entry.SizeBytes():]
if entry.TargetOffset < linux.SizeOfIP6TEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 5200e08ed..f42d73178 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -22,7 +22,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe
nflog("couldn't read entries: %v", err)
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
- if binary.Size(entries) > uintptr(outLen) {
+ if entries.SizeBytes() > outLen {
nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
@@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe
nflog("couldn't read entries: %v", err)
return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
- if binary.Size(entries) > uintptr(outLen) {
+ if entries.SizeBytes() > outLen {
nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
@@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
- binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace)
+ replace.UnmarshalBytes(replaceBuf)
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
@@ -274,10 +273,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
- // make sure all other chains point to ACCEPT rules.
+ // Since we don't support FORWARD yet, make sure all other chains point to
+ // ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward {
if ruleIdx == stack.HookUnset {
continue
}
@@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
}
var match linux.XTEntryMatch
- buf := optVal[:linux.SizeOfXTEntryMatch]
- binary.Unmarshal(buf, hostarch.ByteOrder, &match)
+ buf := optVal[:match.SizeBytes()]
+ match.UnmarshalUnsafe(buf)
nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
// Check some invariants.
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index b2cc6be20..60845cab3 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte {
}
}
- buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
- return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo))
+ buf := marshal.Marshal(&iptOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, buf)
}
// unmarshal implements matchMaker.unmarshal.
@@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
- binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 80f8c6430..e94aceb92 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,11 +15,12 @@
package netfilter
import (
+ "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +36,11 @@ const ErrorTargetName = "ERROR"
// change the destination port and/or IP for packets.
const RedirectTargetName = "REDIRECT"
+// SNATTargetName is used to mark targets as SNAT targets. SNAT targets
+// should only be reached from the NAT table. These targets change the
+// source port and/or IP for packets.
+const SNATTargetName = "SNAT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -59,6 +65,13 @@ func init() {
registerTargetMaker(&nfNATTargetMaker{
NetworkProtocol: header.IPv6ProtocolNumber,
})
+
+ registerTargetMaker(&snatTargetMakerV4{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&snatTargetMakerV6{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
}
// The stack package provides some basic, useful targets for us. The following
@@ -131,6 +144,17 @@ func (rt *redirectTarget) id() targetID {
}
}
+type snatTarget struct {
+ stack.SNATTarget
+}
+
+func (st *snatTarget) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
@@ -166,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte {
Verdict: verdict,
}
- ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -176,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
var standardTarget linux.XTStandardTarget
- buf = buf[:linux.SizeOfXTStandardTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget)
+ standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()])
if standardTarget.Verdict < 0 {
// A Verdict < 0 indicates a non-jump verdict.
@@ -222,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte {
copy(xt.Name[:], errorName)
copy(xt.Target.Name[:], ErrorTargetName)
- ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -233,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt)
+ errTgt.UnmarshalUnsafe(buf)
// Error targets are used in 2 cases:
// * An actual error case. These rules have an error named
@@ -276,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte {
}
copy(xt.Target.Name[:], RedirectTargetName)
- ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -297,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &rt)
+ rt.UnmarshalUnsafe(buf)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
target := redirectTarget{RedirectTarget: stack.RedirectTarget{
@@ -336,12 +356,13 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return &target, nil
}
+// +marshal
type nfNATTarget struct {
Target linux.XTEntryTarget
Range linux.NFNATRange
}
-const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
@@ -358,7 +379,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
- TargetSize: nfNATMarhsalledSize,
+ TargetSize: nfNATMarshalledSize,
},
Range: linux.NFNATRange{
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
@@ -371,12 +392,11 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
- ret := make([]byte, 0, nfNATMarhsalledSize)
- return binary.Marshal(ret, hostarch.ByteOrder, nt)
+ return marshal.Marshal(&nt)
}
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
- if size := nfNATMarhsalledSize; len(buf) < size {
+ if size := nfNATMarshalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
}
@@ -387,8 +407,8 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
- binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ natRange.UnmarshalUnsafe(buf)
// We don't support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -418,6 +438,159 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return &target, nil
}
+type snatTargetMakerV4 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV4) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
+func (*snatTargetMakerV4) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ // This is an SNAT target named "SNAT".
+ xt := linux.XTSNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTSNATTarget,
+ },
+ }
+ copy(xt.Target.Name[:], SNATTargetName)
+
+ xt.NfRange.RangeSize = 1
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(st.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
+ copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr)
+ copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr)
+ return marshal.Marshal(&xt)
+}
+
+func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if len(buf) < linux.SizeOfXTSNATTarget {
+ nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV4: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var st linux.XTSNATTarget
+ buf = buf[:linux.SizeOfXTSNATTarget]
+ st.UnmarshalUnsafe(buf)
+
+ // Copy linux.XTSNATTarget to stack.SNATTarget.
+ target := snatTarget{SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
+
+ // RangeSize should be 1.
+ nfRange := st.NfRange
+ if nfRange.RangeSize != 1 {
+ nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
+ // choose one automatically.
+ if nfRange.RangeIPV4.MinPort == 0 {
+ nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
+}
+
+type snatTargetMakerV6 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV6) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ revision: 1,
+ }
+}
+
+func (*snatTargetMakerV6) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ nt := nfNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: nfNATMarshalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+ },
+ }
+ copy(nt.Target.Name[:], SNATTargetName)
+ copy(nt.Range.MinAddr[:], st.Addr)
+ copy(nt.Range.MaxAddr[:], st.Addr)
+ nt.Range.MinProto = htons(st.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ return marshal.Marshal(&nt)
+}
+
+func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if size := nfNATMarshalledSize; len(buf) < size {
+ nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV6: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ natRange.UnmarshalUnsafe(buf)
+
+ // TODO(gvisor.dev/issue/5689): Support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("snatTargetMakerV6: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
+ if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := snatTarget{
+ SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ },
+ }
+
+ return &target, nil
+}
+
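htons and ntohs above are package-local helpers (not shown in this diff) that swap a 16-bit port between host and network byte order; on a little-endian host both directions are the same byte swap. A sketch under that assumption:

package main

import "fmt"

// htons converts host order to network (big-endian) order on a
// little-endian host; ntohs is the same involution.
func htons(v uint16) uint16 { return v<<8 | v>>8 }
func ntohs(v uint16) uint16 { return htons(v) }

func main() {
	fmt.Printf("%#04x\n", htons(8080)) // 0x901f
}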
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
@@ -453,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T
return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
- buf := optVal[:linux.SizeOfXTEntryTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &target)
+ target.UnmarshalUnsafe(optVal[:target.SizeBytes()])
return unmarshalTarget(target, filter, optVal)
}
@@ -480,7 +652,7 @@ func (jt *JumpTarget) id() targetID {
}
// Action implements stack.Target.Action.
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 69557f515..95bb9826e 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte {
DestinationPortStart: matcher.destinationPortStart,
DestinationPortEnd: matcher.destinationPortEnd,
}
- buf := make([]byte, 0, linux.SizeOfXTTCP)
- return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp))
+ return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.XTTCP
- binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
nflog("parseMatchers: parsed XTTCP: %+v", matchData)
if matchData.Option != 0 ||
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 6a60e6bd6..fb8be27e6 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte {
DestinationPortStart: matcher.destinationPortStart,
DestinationPortEnd: matcher.destinationPortEnd,
}
- buf := make([]byte, 0, linux.SizeOfXTUDP)
- return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp))
+ return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may exceed what's
// strictly necessary to hold matchData.
var matchData linux.XTUDP
- binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
nflog("parseMatchers: parsed XTUDP: %+v", matchData)
if matchData.InverseFlags != 0 {
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 171b95c63..64cd263da 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -14,7 +14,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
@@ -50,5 +50,7 @@ go_test(
deps = [
":netlink",
"//pkg/abi/linux",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index ab0e68af7..80385bfdc 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -19,15 +19,17 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
)
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return binary.AlignUp(length, align) - length
+ return bits.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -42,7 +44,7 @@ type Message struct {
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
hdr: hdr,
- buf: binary.Marshal(nil, hostarch.ByteOrder, hdr),
+ buf: marshal.Marshal(&hdr),
}
}
@@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
return
}
var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
// Msg portion.
totalMsgLen := int(hdr.Length)
@@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader {
// GetData unmarshals the payload message header from this netlink message, and
// returns the attributes portion.
-func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) {
b := BytesView(m.buf)
_, ok := b.Extract(linux.NetlinkMessageHeaderSize)
@@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
return nil, false
}
- size := int(binary.Size(msg))
+ size := msg.SizeBytes()
msgBytes, ok := b.Extract(size)
if !ok {
return nil, false
}
- binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg)
+ msg.UnmarshalUnsafe(msgBytes)
numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
// Linux permits the last message to be unaligned; just consume all of it.
@@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) {
}
// Put serializes v into the message.
-func (m *Message) Put(v interface{}) {
- m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v)
+func (m *Message) Put(v marshal.Marshallable) {
+ m.buf = append(m.buf, marshal.Marshal(v)...)
}
// PutAttr adds v to the message as a netlink attribute.
//
// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
-// binary.Size(v) fits in math.MaxUint16 bytes.
-func (m *Message) PutAttr(atype uint16, v interface{}) {
- l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+// v.SizeBytes()) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) {
+ l := linux.NetlinkAttrHeaderSize + v.SizeBytes()
if l > math.MaxUint16 {
panic(fmt.Sprintf("attribute too large: %d", l))
}
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
m.Put(v)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
// PutAttrString adds s to the message as a netlink attribute.
func (m *Message) PutAttrString(atype uint16, s string) {
l := linux.NetlinkAttrHeaderSize + len(s) + 1
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
// String + NUL-termination.
- m.Put([]byte(s))
+ m.Put(primitive.AsByteSlice([]byte(s)))
m.putZeros(1)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest
if !ok {
return
}
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
if !ok {
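The attribute padding that PutAttr and PutAttrString now compute with bits.AlignUp, checked in isolation (constants inlined for brevity):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/bits"
)

func main() {
	const nlaAlignTo = 4 // linux.NLA_ALIGNTO
	l := 11              // attribute header + payload length
	aligned := bits.AlignUp(l, nlaAlignTo)
	fmt.Println(aligned, aligned-l) // 12 1: one zero byte of padding
}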
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index ef13d9386..968968469 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -20,13 +20,31 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
type dummyNetlinkMsg struct {
+ marshal.StubMarshallable
Foo uint16
}
+func (*dummyNetlinkMsg) SizeBytes() int {
+ return 2
+}
+
+func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
+ p := primitive.Uint16(m.Foo)
+ p.MarshalUnsafe(dst)
+}
+
+func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
+ var p primitive.Uint16
+ p.UnmarshalUnsafe(src)
+ m.Foo = uint16(p)
+}
+
func TestParseMessage(t *testing.T) {
tests := []struct {
desc string
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 744fc74f4..c6c04b4e3 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -11,6 +11,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 5a2255db3..86f6419dc 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
Type: linux.RTM_NEWLINK,
})
- m.Put(linux.InterfaceInfoMessage{
+ m.Put(&linux.InterfaceInfoMessage{
Family: linux.AF_UNSPEC,
Type: i.DeviceType,
Index: idx,
@@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
})
m.PutAttrString(linux.IFLA_IFNAME, i.Name)
- m.PutAttr(linux.IFLA_MTU, i.MTU)
+ m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU))
mac := make([]byte, 6)
brd := mac
@@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
mac = i.Addr
brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
}
- m.PutAttr(linux.IFLA_ADDRESS, mac)
- m.PutAttr(linux.IFLA_BROADCAST, brd)
+ m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac))
+ m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd))
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl
Type: linux.RTM_NEWADDR,
})
- m.Put(linux.InterfaceAddrMessage{
+ m.Put(&linux.InterfaceAddrMessage{
Family: a.Family,
PrefixLen: a.PrefixLen,
Index: uint32(id),
})
- m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
- m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+ addr := primitive.ByteSlice([]byte(a.Addr))
+ m.PutAttr(linux.IFA_LOCAL, &addr)
+ m.PutAttr(linux.IFA_ADDRESS, &addr)
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Type: linux.RTM_NEWROUTE,
})
- m.Put(linux.RouteMessage{
+ m.Put(&linux.RouteMessage{
Family: rt.Family,
DstLen: rt.DstLen,
SrcLen: rt.SrcLen,
@@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Flags: rt.Flags,
})
- m.PutAttr(254, []byte{123})
+ m.PutAttr(254, primitive.AsByteSlice([]byte{123}))
if rt.DstLen > 0 {
- m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr))
}
if rt.SrcLen > 0 {
- m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr))
}
if rt.OutputInterface != 0 {
- m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface))
}
if len(rt.GatewayAddr) > 0 {
- m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr))
}
// TODO(gvisor.dev/issue/578): There are many more attributes.
@@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
hdr := msg.Header()
// All messages start with a 1 byte protocol family.
- var family uint8
+ var family primitive.Uint8
if _, ok := msg.GetData(&family); !ok {
// Linux ignores messages missing the protocol family. See
// net/core/rtnetlink.c:rtnetlink_rcv_msg.
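With GetData now taking a marshal.Marshallable, reading the leading one-byte family looks like this in isolation (a sketch; protocolFamily is a hypothetical helper, not part of this change):

package route

import (
	"gvisor.dev/gvisor/pkg/marshal/primitive"
	"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)

// protocolFamily mirrors the family read in ProcessMessage above.
func protocolFamily(msg *netlink.Message) (uint8, bool) {
	var family primitive.Uint8
	if _, ok := msg.GetData(&family); !ok {
		return 0, false
	}
	return uint8(family), true
}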
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 30c297149..d75a2879f 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -20,7 +20,6 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa)
+ sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
@@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- sendBufferSizeP := primitive.Int32(s.sendBufferSize)
- return &sendBufferSizeP, nil
+ return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have a limit on the receive size.
- recvBufferSizeP := primitive.Int32(math.MaxInt32)
- return &recvBufferSizeP, nil
+ return primitive.AllocateInt32(math.MaxInt32), nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
Family: linux.AF_NETLINK,
PortID: uint32(s.portID),
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// GetPeerName implements socket.Socket.GetPeerName.
@@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
// must be the kernel.
PortID: 0,
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// RecvMsg implements socket.Socket.RecvMsg.
@@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Family: linux.AF_NETLINK,
PortID: 0,
}
- fromLen := uint32(binary.Size(from))
+ fromLen := uint32(from.SizeBytes())
trunc := flags&linux.MSG_TRUNC != 0
@@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys
})
// Add the dump_done_errno payload.
- m.Put(int64(0))
+ m.Put(primitive.AllocateInt64(0))
_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
@@ -658,7 +655,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: int32(-err.ToLinux().Number()),
Header: hdr,
})
@@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: 0,
Header: hdr,
})
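Building the NLMSG_DONE trailer with the new primitive helpers, as a standalone sketch (the surrounding MessageSet plumbing is elided):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/marshal/primitive"
	"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)

func main() {
	m := netlink.NewMessage(linux.NetlinkMessageHeader{Type: linux.NLMSG_DONE})
	// The dump_done_errno payload, as in sendResponse above.
	m.Put(primitive.AllocateInt64(0))
	fmt.Println(len(m.Finalize())) // 16-byte header + 8-byte payload = 24
}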
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 0b39a5b67..9561b7c25 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -19,7 +19,6 @@ go_library(
],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed6572bab..60ef33360 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "encoding/binary"
"fmt"
"io"
"io/ioutil"
@@ -35,7 +36,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
@@ -199,6 +199,13 @@ var Metrics = tcpip.Stats{
OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."),
OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."),
OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."),
+ Forwarding: tcpip.IPForwardingStats{
+ Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."),
+ ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."),
+ LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."),
+ LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."),
+ Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."),
+ },
},
ARP: tcpip.ARPStats{
PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."),
@@ -242,6 +249,7 @@ var Metrics = tcpip.Stats{
FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of times TCP failed to reserve a port."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -374,9 +382,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
}), nil
}
-var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
-var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
-var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
+var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
+var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
+var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -612,7 +620,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < sockAddrLinkSize {
return syserr.ErrInvalidArgument
}
- binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
+ a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
if a.Protocol != uint16(s.protocol) {
return syserr.ErrInvalidArgument
@@ -885,10 +893,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ size := ep.SocketOptions().GetReceiveBufferSize()
if size > math.MaxInt32 {
size = math.MaxInt32
@@ -1314,7 +1319,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &v, nil
case linux.IP6T_ORIGINAL_DST:
- if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
}
@@ -1511,7 +1516,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return &v, nil
case linux.SO_ORIGINAL_DST:
- if outLen < int(binary.Size(linux.SockAddrInet{})) {
+ if outLen < sockAddrInetSize {
return nil, syserr.ErrInvalidArgument
}
@@ -1661,7 +1666,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := hostarch.ByteOrder.Uint32(optVal)
- ep.SocketOptions().SetSendBufferSize(int64(v), true)
+ ep.SocketOptions().SetSendBufferSize(int64(v), true /* notify */)
return nil
case linux.SO_RCVBUF:
@@ -1670,7 +1675,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := hostarch.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
+ ep.SocketOptions().SetReceiveBufferSize(int64(v), true /* notify */)
+ return nil
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1743,7 +1749,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1756,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1792,7 +1798,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Linger
- binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfLinger])
+
+ if v != (linux.Linger{}) {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: v.OnOff != 0,
@@ -2091,9 +2101,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
var (
- inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
- inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
- inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{}))
+ inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes()
+ inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes()
+ inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes()
)
// copyInMulticastRequest copies in a variable-size multicast request. The
@@ -2118,12 +2128,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
if len(optVal) >= inetMulticastRequestWithNICSize {
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req)
+ req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize])
return req, nil
}
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest)
+ req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize])
return req, nil
}
@@ -2133,7 +2143,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse
}
var req linux.Inet6MulticastRequest
- binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req)
+ req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize])
return req, nil
}
@@ -3102,8 +3112,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
continue
}
// Populate ifr.ifr_netmask (type sockaddr).
- hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
- hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET))
+ hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0)
var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
// Netmask is expected to be returned as a big endian
// value.
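The netmask math those PutUint16 calls lead into can be checked standalone with the stdlib encoding/binary this file now imports (a sketch; prefix length 24 is illustrative):

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	var prefixLen uint32 = 24
	mask := uint32(0xffffffff) << (32 - prefixLen)
	var b [4]byte
	// Netmask is expected to be returned as a big endian value.
	binary.BigEndian.PutUint32(b[:], mask)
	fmt.Println(b) // [255 255 255 0]
}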
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 4c3d48096..9e56487a6 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -24,7 +24,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
switch family {
case unix.AF_INET:
var addr linux.SockAddrInet
- binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_INET6:
var addr linux.SockAddrInet6
- binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_UNIX:
var addr linux.SockAddrUnix
- binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_NETLINK:
var addr linux.SockAddrNetlink
- binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
default:
panic(fmt.Sprintf("Unsupported socket family %v", family))
@@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrInetSize])
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrInet6Size])
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrLinkSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrLinkSize])
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
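Each arm of UnmarshalSockAddr follows the same SizeBytes-bounded pattern; one family in isolation (inet4SockAddr is a hypothetical helper; callers must guarantee len(data) >= addr.SizeBytes()):

package socket

import "gvisor.dev/gvisor/pkg/abi/linux"

// inet4SockAddr shows the per-family unmarshal pattern used above.
func inet4SockAddr(data []byte) *linux.SockAddrInet {
	var addr linux.SockAddrInet
	addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
	return &addr
}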
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 159b8f90f..33f9aeb06 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -130,7 +130,8 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
}
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
@@ -175,8 +176,9 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
return ep
}
@@ -299,8 +301,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
- ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits)
+ ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
readQueue.InitRefs()
@@ -343,11 +346,11 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
return nil
default:
- // Busy; return ECONNREFUSED per spec.
+ // Busy; return EAGAIN per spec.
ne.Close(ctx)
e.Unlock()
ce.Unlock()
- return syserr.ErrConnectionRefused
+ return syserr.ErrTryAgain
}
}
@@ -366,6 +369,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
// to reflect this endpoint's send buffer size.
if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() {
e.ops.SetSendBufferSize(bufSz, false /* notify */)
+ e.ops.SetReceiveBufferSize(bufSz, false /* notify */)
}
}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
index 590b0bd01..b20334d4f 100644
--- a/pkg/sentry/socket/unix/transport/connectioned_state.go
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -54,5 +54,5 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd
// afterLoad is invoked by stateify.
func (e *connectionedEndpoint) afterLoad() {
- e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index d0df28b59..61338728a 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -45,7 +45,8 @@ func NewConnectionless(ctx context.Context) Endpoint {
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go
index 2ef337ec8..1bb71baf7 100644
--- a/pkg/sentry/socket/unix/transport/connectionless_state.go
+++ b/pkg/sentry/socket/unix/transport/connectionless_state.go
@@ -16,5 +16,5 @@ package transport
// afterLoad is invoked by stateify.
func (e *connectionlessEndpoint) afterLoad() {
- e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 0c5f5ab42..837ab4fde 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -868,11 +868,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
+ log.Warningf("Unsupported socket option: %d", opt)
return nil
}
@@ -905,19 +901,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
return int(v), nil
- case tcpip.ReceiveBufferSizeOption:
- e.Lock()
- if e.receiver == nil {
- e.Unlock()
- return -1, &tcpip.ErrNotConnected{}
- }
- v := e.receiver.RecvMaxQueueSize()
- e.Unlock()
- if v < 0 {
- return -1, &tcpip.ErrQueueSizeNotSupported{}
- }
- return int(v), nil
-
default:
log.Warningf("Unsupported socket option: %d", opt)
return -1, &tcpip.ErrUnknownProtocolOption{}
@@ -1029,3 +1012,15 @@ func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption {
Max: maxBufferSize,
}
}
+
+// getReceiveBufferLimits implements tcpip.GetReceiveBufferLimits.
+//
+// We define min, max and default values for the unix socket implementation.
+// Unix sockets do not use the receive buffer.
+func getReceiveBufferLimits(tcpip.StackHandler) tcpip.ReceiveBufferSizeOption {
+ return tcpip.ReceiveBufferSizeOption{
+ Min: minimumBufferSize,
+ Default: defaultBufferSize,
+ Max: maxBufferSize,
+ }
+}
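The clamping that consumes these limits lives in tcpip.SocketOptions; a hedged sketch of what one would expect it to do (clampRcvBuf is hypothetical, not the actual implementation):

package transport

import "gvisor.dev/gvisor/pkg/tcpip"

// clampRcvBuf illustrates clamping a requested receive buffer size to the
// limits returned by getReceiveBufferLimits.
func clampRcvBuf(req int64, lim tcpip.ReceiveBufferSizeOption) int64 {
	if req < int64(lim.Min) {
		return int64(lim.Min)
	}
	if req > int64(lim.Max) {
		return int64(lim.Max)
	}
	return req
}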
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 2ebd77f82..1fbbd133c 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -25,7 +25,6 @@ go_library(
":strace_go_proto",
"//pkg/abi",
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bits",
"//pkg/eventchannel",
"//pkg/hostarch",
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index 71b92eaee..d66befe81 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{
433: makeSyscallInfo("fspick", FD, Path, Hex),
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
+ 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet),
}
func init() {
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index bd7361a52..1a2d7d75f 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{
433: makeSyscallInfo("fspick", FD, Path, Hex),
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
+ 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet),
}
func init() {
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index e5b7f9b96..f4aab25b0 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -20,14 +20,13 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// SocketFamily are the possible socket(2) families.
@@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{
linux.SO_TIMESTAMP: "SO_TIMESTAMP",
}
+func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights {
+ count := len(src) / linux.SizeOfControlMessageRight
+ cmr := make(linux.ControlMessageRights, count)
+ for i := range cmr {
+ cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:]))
+ }
+ return cmr
+}
+
func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string {
if length > maxBytes {
return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length)
@@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
+ h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
var skipData bool
level := "SOL_SOCKET"
@@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
-
- numRights := rightsSize / linux.SizeOfControlMessageRight
- fds := make(linux.ControlMessageRights, numRights)
- binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds)
-
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
+ fds := unmarshalControlMessageRights(buf[i : i+rightsSize])
rights := make([]string, 0, len(fds))
for _, fd := range fds {
rights = append(rights, fmt.Sprint(fd))
@@ -258,7 +262,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
+ creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
@@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var tv linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv)
+ tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
@@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
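The new rights decoder replaces the old slice-based binary.Unmarshal; the same loop standalone (output assumes a little endian host):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	const sizeOfRight = 4 // linux.SizeOfControlMessageRight
	src := []byte{3, 0, 0, 0, 7, 0, 0, 0}
	fds := make([]int32, len(src)/sizeOfRight)
	for i := range fds {
		fds[i] = int32(hostarch.ByteOrder.Uint32(src[i*sizeOfRight:]))
	}
	fmt.Println(fds) // [3 7]
}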
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index e115683f8..3b4d79889 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
}
// WaitEpoll implements the epoll_wait(2) linux syscall.
-func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) {
+func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) {
// Get epoll from the file descriptor.
epollfile := t.GetFile(fd)
if epollfile == nil {
@@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve
// Try to read events and return right away if we got them or if the
// caller requested a non-blocking "wait".
r := e.ReadEvents(max)
- if len(r) != 0 || timeout == 0 {
+ if len(r) != 0 || timeoutInNanos == 0 {
return r, nil
}
@@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve
// and register with the epoll object for readability events.
var haveDeadline bool
var deadline ktime.Time
- if timeout > 0 {
- timeoutDur := time.Duration(timeout) * time.Millisecond
+ if timeoutInNanos > 0 {
+ timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond
deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
haveDeadline = true
}
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index efec93f73..6eabfd219 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -29,10 +29,17 @@ import (
)
var (
- partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.")
- partialResultOnce sync.Once
+ partialResultOnce sync.Once
)
+// incrementPartialResultMetric increments the "partial_result" field of
+// metric.WeirdnessMetric. It exists as a named function because sync.Once.Do
+// requires a function that takes no arguments, whereas Increment is variadic.
+func incrementPartialResultMetric() {
+ metric.WeirdnessMetric.Increment("partial_result")
+}
+
// HandleIOErrorVFS2 handles special error cases for partial results. For some
// errors, we may consume the error and return only the partial read/write.
//
@@ -48,7 +55,7 @@ func HandleIOErrorVFS2(ctx context.Context, partialResult bool, ioerr, intr erro
root := vfs.RootFromContext(ctx)
name, _ := fs.PathnameWithDeleted(ctx, root, f.VirtualDentry())
log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name)
- partialResultOnce.Do(partialResultMetric.Increment)
+ partialResultOnce.Do(incrementPartialResultMetric)
}
return nil
}
@@ -66,7 +73,7 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o
// An unknown error is encountered with a partial read/write.
name, _ := f.Dirent.FullName(nil /* ignore chroot */)
log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations)
- partialResultOnce.Do(partialResultMetric.Increment)
+ partialResultOnce.Do(incrementPartialResultMetric)
}
return nil
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 2d2212605..090c5ffcb 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{
433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{
0xffffffffff600000: 96, // vsyscall gettimeofday(2)
@@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{
433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{},
Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 7f460d30b..69cbc98d0 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
@@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// EpollWait implements the epoll_wait(2) linux syscall.
-func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- epfd := args[0].Int()
- eventsAddr := args[1].Pointer()
- maxEvents := int(args[2].Int())
- timeout := int(args[3].Int())
-
- r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout)
+func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
+ r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos)
if err != nil {
return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
}
@@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
return uintptr(len(r)), nil, nil
+}
+
+// EpollWait implements the epoll_wait(2) linux syscall.
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ // Convert milliseconds to nanoseconds.
+ timeoutInNanos := int64(args[3].Int()) * 1000000
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
}
// EpollPwait implements the epoll_pwait(2) linux syscall.
@@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+// EpollPwait2 implements the epoll_pwait2(2) linux syscall.
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutPtr := args[3].Pointer()
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+ haveTimeout := timeoutPtr != 0
+
+ var timeoutInNanos int64 = -1
+ if haveTimeout {
+ timeout, err := copyTimespecIn(t, timeoutPtr)
+ if err != nil {
+ return 0, nil, err
+ }
+ timeoutInNanos = timeout.ToNsec()
+ }
+
+ if maskAddr != 0 {
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
+}
+
// LINT.ThenChange(vfs2/epoll.go)
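Both entry points now funnel into waitEpoll as nanoseconds, with -1 meaning block indefinitely; the two conversions side by side (values illustrative):

package main

import (
	"fmt"
	"time"
)

func main() {
	// epoll_wait: milliseconds from userspace.
	ms := int32(250)
	fromMs := int64(ms) * 1000000 // the conversion in EpollWait

	// epoll_pwait2: a timespec from userspace, already in sec/nsec.
	sec, nsec := int64(0), int64(250000000)
	fromTs := sec*int64(time.Second) + nsec

	fmt.Println(fromMs == fromTs) // true
}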
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9bdf6d3d8..e07917613 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -35,12 +35,6 @@ import (
// LINT.IfChange
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers up to INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum listen backlog we support.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -367,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFile(fd)
@@ -382,14 +379,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats the incoming backlog as a uint, with an upper limit
+ // defined by sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
backlog = maxListenBacklog
}
+ // Accept one more than the configured listen backlog to keep parity with
+ // Linux, which admits one extra connection due to its missing equality
+ // check here:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ //
+ // For unix domain sockets, the following check
+ // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293
+ // will let one extra connect through, since it checks for a receive
+ // queue len > backlog rather than >=.
+ backlog++
+
return 0, nil, s.Listen(t, int(backlog)).ToError()
}
@@ -457,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(v.SizeBytes())
- if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil {
return 0, nil, err
}
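The backlog handling after this change, reduced to its arithmetic (a sketch; note that a negative int from userspace becomes a huge uint and is therefore capped):

package main

import "fmt"

const maxListenBacklog = 1024

func effectiveBacklog(arg uint32) int {
	b := arg
	if b > maxListenBacklog {
		b = maxListenBacklog
	}
	// One more than configured, matching Linux's missing equality check.
	return int(b + 1)
}

func main() {
	fmt.Println(effectiveBacklog(0))          // 1
	fmt.Println(effectiveBacklog(4294967295)) // 1025 (was -1 as an int)
}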
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index b980aa43e..047d955b6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// EpollWait implements Linux syscall epoll_wait(2).
-func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- epfd := args[0].Int()
- eventsAddr := args[1].Pointer()
- maxEvents := int(args[2].Int())
- timeout := int(args[3].Int())
-
+func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
return 0, nil, syserror.EINVAL
@@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
return 0, nil, err
}
- if timeout == 0 {
+ if timeoutInNanos == 0 {
return 0, nil, nil
}
// In the first iteration of this loop, register with the epoll
@@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
defer epfile.EventUnregister(&w)
} else {
// Set up the timer if a timeout was specified.
- if timeout > 0 && !haveDeadline {
- timeoutDur := time.Duration(timeout) * time.Millisecond
+ if timeoutInNanos > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond
deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
haveDeadline = true
}
@@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
}
}
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutInNanos := int64(args[3].Int()) * 1000000
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
}
// EpollPwait implements Linux syscall epoll_pwait(2).
@@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// EpollPwait2 implements Linux syscall epoll_pwait2(2).
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutPtr := args[3].Pointer()
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+ haveTimeout := timeoutPtr != 0
+
+ var timeoutInNanos int64 = -1
+ if haveTimeout {
+ var timeout linux.Timespec
+ if _, err := timeout.CopyIn(t, timeoutPtr); err != nil {
+ return 0, nil, err
+ }
+ timeoutInNanos = timeout.ToNsec()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index a87a66146..69f69e3af 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -35,12 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
)
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers up to INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum listen backlog we support.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -371,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFileVFS2(fd)
@@ -386,14 +383,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats the incoming backlog as a uint, with an upper limit
+ // defined by sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
backlog = maxListenBacklog
}
+ // Accept one more than the configured listen backlog to keep parity with
+ // Linux, which admits one extra connection due to its missing equality
+ // check here:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ //
+ // For unix domain sockets, the following check
+ // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293
+ // will let one extra connect through, since it checks for a receive
+ // queue len > backlog rather than >=.
+ backlog++
+
return 0, nil, s.Listen(t, int(backlog)).ToError()
}
@@ -461,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(v.SizeBytes())
- if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index c50fd97eb..0fc81e694 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -159,6 +159,7 @@ func Override() {
s.Table[327] = syscalls.Supported("preadv2", Preadv2)
s.Table[328] = syscalls.Supported("pwritev2", Pwritev2)
s.Table[332] = syscalls.Supported("statx", Statx)
+ s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2)
s.Init()
// Override ARM64.
@@ -269,6 +270,7 @@ func Override() {
s.Table[286] = syscalls.Supported("preadv2", Preadv2)
s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
s.Table[291] = syscalls.Supported("statx", Statx)
+ s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2)
s.Init()
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 87d8687ce..1f617ca8f 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -32,6 +32,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/metric",
"//pkg/sync",
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
index f9a93115d..39bf1e0de 100644
--- a/pkg/sentry/time/calibrated_clock.go
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -25,11 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// fallbackMetric tracks failed updates. It is not sync, as it is not critical
-// that all occurrences are captured and CalibratedClock may fallback many
-// times.
-var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update")
-
// CalibratedClock implements a clock that tracks a reference clock.
//
// Users should call Update at regular intervals of around approxUpdateInterval
@@ -102,7 +97,7 @@ func (c *CalibratedClock) resetLocked(str string, v ...interface{}) {
c.Warningf(str+" Resetting clock; time may jump.", v...)
c.ready = false
c.ref.Reset()
- fallbackMetric.Increment()
+ metric.WeirdnessMetric.Increment("time_fallback")
}
// updateParams updates the timekeeping parameters based on the passed
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index f612a71b2..176bcc242 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -524,7 +524,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St
Start: fd.vd,
})
stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, err
}
return fd.impl.Stat(ctx, opts)
@@ -539,7 +539,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.SetStat(ctx, opts)
@@ -555,7 +555,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
Start: fd.vd,
})
statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, err
}
return fd.impl.StatFS(ctx)
@@ -701,7 +701,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string
Start: fd.vd,
})
names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return names, err
}
names, err := fd.impl.ListXattr(ctx, size)
@@ -730,7 +730,7 @@ func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions)
Start: fd.vd,
})
val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return val, err
}
return fd.impl.GetXattr(ctx, *opts)
@@ -746,7 +746,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions)
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.SetXattr(ctx, *opts)
@@ -762,7 +762,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error {
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.RemoveXattr(ctx, name)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 1556b41a3..b87d9690a 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface {
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
// with Linux's fs/seq_file.c:single_open().
//
+// If data additionally implements WritableDynamicBytesSource, writes are
+// dispatched to the implementer. The source data is not automatically modified.
+//
// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
// use.
//
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 922f9e697..82fd382c2 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -826,6 +826,9 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
if mnt.Flags.NoExec {
opts += ",noexec"
}
+ if mopts := mnt.fs.Impl().MountOptions(); mopts != "" {
+ opts += "," + mopts
+ }
// Format:
// <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
@@ -970,17 +973,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string {
opts += "," + mopts
}
- // NOTE(b/147673608): If the mount is a cgroup, we also need to include
- // the cgroup name in the options. For now we just read that from the
- // path.
+ // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also
+ // need to include the cgroup name in the options. For now we just read that
+ // from the path. Note that this is only possible when "cgroup" isn't
+ // registered as a valid filesystem type.
//
- // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
- // should get this value from the cgroup itself, and not rely on the
- // path.
+ // TODO(gvisor.dev/issue/190): Once we've removed fake cgroupfs support,
+ // remove this.
+ if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount {
+ // Real cgroupfs available.
+ return opts
+ }
if mnt.fs.FilesystemType().Name() == "cgroup" {
splitPath := strings.Split(mountPath, "/")
cgroupType := splitPath[len(splitPath)-1]
opts += "," + cgroupType
}
+
return opts
}
diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go
index 39fbac987..47848c76b 100644
--- a/pkg/sentry/vfs/opath.go
+++ b/pkg/sentry/vfs/opath.go
@@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err
Start: fd.vfsfd.vd,
})
stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, err
}
@@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) {
Start: fd.vfsfd.vd,
})
statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, err
}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index e4fd55012..97b898aba 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -44,13 +44,10 @@ type ResolvingPath struct {
start *Dentry
pit fspath.Iterator
- flags uint16
- mustBeDir bool // final file must be a directory?
- mustBeDirOrig bool
- symlinks uint8 // number of symlinks traversed
- symlinksOrig uint8
- curPart uint8 // index into parts
- numOrigParts uint8
+ flags uint16
+ mustBeDir bool // final file must be a directory?
+ symlinks uint8 // number of symlinks traversed
+ curPart uint8 // index into parts
creds *auth.Credentials
@@ -60,14 +57,9 @@ type ResolvingPath struct {
nextStart *Dentry // ref held if not nil
absSymlinkTarget fspath.Path
- // ResolvingPath must track up to two relative paths: the "current"
- // relative path, which is updated whenever a relative symlink is
- // encountered, and the "original" relative path, which is updated from the
- // current relative path by handleError() when resolution must change
- // filesystems (due to reaching a mount boundary or absolute symlink) and
- // overwrites the current relative path when Restart() is called.
- parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
- origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
+ // ResolvingPath tracks the current relative path, which is updated
+ // whenever a relative symlink is encountered.
+ parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
}
const (
@@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{
},
}
+// getResolvingPath gets a new ResolvingPath from the pool. Caller must call
+// ResolvingPath.Release() when done.
func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
rp := resolvingPathPool.Get().(*ResolvingPath)
rp.vfs = vfs
@@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat
rp.flags |= rpflagsFollowFinalSymlink
}
rp.mustBeDir = pop.Path.Dir
- rp.mustBeDirOrig = pop.Path.Dir
rp.symlinks = 0
rp.curPart = 0
- rp.numOrigParts = 1
rp.creds = creds
rp.parts[0] = pop.Path.Begin
- rp.origParts[0] = pop.Path.Begin
return rp
}
-func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) {
+// Copy creates another ResolvingPath with the same state as the original.
+// Copies are independent: using the copy does not change the original, and
+// vice versa.
+//
+// Caller must call Release() when done.
+func (rp *ResolvingPath) Copy() *ResolvingPath {
+ copy := resolvingPathPool.Get().(*ResolvingPath)
+	*copy = *rp // All fields are shallow copyable.
+
+	// Take extra references for the copy if the original holds them.
+ if copy.flags&rpflagsHaveStartRef != 0 {
+ copy.start.IncRef()
+ }
+ if copy.flags&rpflagsHaveMountRef != 0 {
+ copy.mount.IncRef()
+ }
+ // Reset error state.
+ copy.nextStart = nil
+ copy.nextMount = nil
+ return copy
+}
+
+// Release decrements references if needed and returns the object to the pool.
+func (rp *ResolvingPath) Release(ctx context.Context) {
rp.root = VirtualDentry{}
rp.decRefStartAndMount(ctx)
rp.mount = nil
@@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() {
}
}
-// Restart resets the stream of path components represented by rp to its state
-// on entry to the current FilesystemImpl method.
-func (rp *ResolvingPath) Restart(ctx context.Context) {
- rp.pit = rp.origParts[rp.numOrigParts-1]
- rp.mustBeDir = rp.mustBeDirOrig
- rp.symlinks = rp.symlinksOrig
- rp.curPart = rp.numOrigParts - 1
- copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
- rp.releaseErrorState(ctx)
-}
-
-func (rp *ResolvingPath) relpathCommit() {
- rp.mustBeDirOrig = rp.mustBeDir
- rp.symlinksOrig = rp.symlinks
- rp.numOrigParts = rp.curPart + 1
- copy(rp.origParts[:rp.curPart], rp.parts[:])
- rp.origParts[rp.curPart] = rp.pit
-}
-
// CheckRoot is called before resolving the parent of the Dentry d. If the
// Dentry is contextually a VFS root, such that path resolution should treat
// d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the
@@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
rp.nextMount = nil
rp.nextStart = nil
- // Commit the previous FileystemImpl's progress through the relative
- // path. (Don't consume the path component that caused us to traverse
+ // Don't consume the path component that caused us to traverse
// through the mount root - i.e. the ".." - because we still need to
- // resolve the mount point's parent in the new FilesystemImpl.)
- rp.relpathCommit()
+ // resolve the mount point's parent in the new FilesystemImpl.
+ //
// Restart path resolution on the new Mount. Don't bother calling
// rp.releaseErrorState() since we already set nextMount and nextStart
// to nil above.
@@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.nextMount = nil
// Consume the path component that represented the mount point.
rp.Advance()
- // Commit the previous FilesystemImpl's progress through the relative
- // path.
- rp.relpathCommit()
// Restart path resolution on the new Mount.
rp.releaseErrorState(ctx)
return true
@@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.Advance()
// Prepend the symlink target to the relative path.
rp.relpathPrepend(rp.absSymlinkTarget)
- // Commit the previous FilesystemImpl's progress through the relative
- // path, including the symlink target we just prepended.
- rp.relpathCommit()
// Restart path resolution on the new Mount.
rp.releaseErrorState(ctx)
return true
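
With Restart() and the orig* bookkeeping gone, callers that may need to retry resolution are expected to snapshot the path with Copy() and Release() both. A stand-alone toy mirroring those semantics (pathState is a hypothetical stand-in for ResolvingPath, not gVisor API):

package main

import "fmt"

// pathState is a stand-in for ResolvingPath: a consumable iterator whose
// copies are independent, mirroring Copy() in the patch above.
type pathState struct{ parts []string }

func (p *pathState) copyState() *pathState {
	c := *p // shallow copy; the real Copy() also takes mount/dentry refs
	return &c
}

func (p *pathState) advance() string {
	head := p.parts[0]
	p.parts = p.parts[1:]
	return head
}

func main() {
	rp := &pathState{parts: []string{"a", "b", "c"}}
	saved := rp.copyState() // snapshot before consuming, like rp.Copy()
	rp.advance()
	rp.advance()
	// The copy still sees the full path, so a retry can start over without
	// the removed Restart()/relpathCommit machinery.
	fmt.Println(len(saved.parts), len(rp.parts)) // 3 1
}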
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 00f1847d8..87fdcf403 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede
dentry: d,
}
rp.mount.IncRef()
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return vd, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return VirtualDentry{}, err
}
}
@@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
rp.mount.IncRef()
name := rp.Component()
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return parentVD, name, nil
}
if checkInvariants {
@@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return VirtualDentry{}, "", err
}
}
@@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
for {
err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldVD.DecRef(ctx)
return nil
}
@@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldVD.DecRef(ctx)
return err
}
@@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
rp := vfs.getResolvingPath(creds, pop)
if opts.Flags&linux.O_DIRECTORY != 0 {
rp.mustBeDir = true
- rp.mustBeDirOrig = true
}
// Ignore O_PATH for verity, as verity performs extra operations on the fd for verification.
// The underlying filesystem that verity wraps opens the fd with O_PATH.
@@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
for {
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
if opts.FileExec {
if fd.Mount().Flags.NoExec {
@@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
return fd, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden
for {
target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return target, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return "", err
}
}
@@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldParentVD.DecRef(ctx)
return nil
}
@@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldParentVD.DecRef(ctx)
return err
}
@@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.RmdirAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential
for {
stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return linux.Statx{}, err
}
}
@@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
for {
statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return linux.Statfs{}, err
}
}
@@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -661,7 +660,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
for {
bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return bep, nil
}
if checkInvariants {
@@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
for {
names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return names, nil
}
if err == syserror.ENOTSUP {
@@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
// subsystem to return security extended attributes, which by
// default don't exist.
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden
for {
val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return val, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return "", err
}
}
@@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden
for {
err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
for {
err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
diff --git a/pkg/shim/utils/volumes.go b/pkg/shim/utils/volumes.go
index 52a428179..cdcb88229 100644
--- a/pkg/shim/utils/volumes.go
+++ b/pkg/shim/utils/volumes.go
@@ -91,11 +91,9 @@ func isVolumePath(volume, path string) (bool, error) {
// UpdateVolumeAnnotations add necessary OCI annotations for gvisor
// volume optimization.
func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
- var (
- uid string
- err error
- )
+ var uid string
if IsSandbox(s) {
+ var err error
uid, err = podUID(s)
if err != nil {
// Skip if we can't get pod UID, because this doesn't work
@@ -123,21 +121,18 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
} else {
// This is a container.
for i := range s.Mounts {
- // An error is returned for sandbox if source
- // annotation is not successfully applied, so
- // it is guaranteed that the source annotation
- // for sandbox has already been successfully
- // applied at this point.
+ // An error is returned for sandbox if source annotation is not
+ // successfully applied, so it is guaranteed that the source annotation
+ // for sandbox has already been successfully applied at this point.
//
- // The volume name is unique inside a pod, so
- // matching without podUID is fine here.
+ // The volume name is unique inside a pod, so matching without podUID
+ // is fine here.
//
- // TODO: Pass podUID down to shim for containers to do
- // more accurate matching.
+ // TODO: Pass podUID down to shim for containers to do more accurate
+ // matching.
if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes {
- // gVisor requires the container mount type to match
- // sandbox mount type.
- s.Mounts[i].Type = v
+ // Container mount type must match the sandbox's mount type.
+ changeMountType(&s.Mounts[i], v)
updated = true
}
}
@@ -153,3 +148,22 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
}
return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666)
}
+
+func changeMountType(m *specs.Mount, newType string) {
+ m.Type = newType
+
+	// The OCI spec allows bind mounts to be specified via options only. So if
+	// the new type is not bind, remove bind/rbind from the options.
+ //
+ // "For bind mounts (when options include either bind or rbind), the type is
+ // a dummy, often "none" (not listed in /proc/filesystems)."
+ if newType != "bind" {
+ newOpts := make([]string, 0, len(m.Options))
+ for _, opt := range m.Options {
+ if opt != "rbind" && opt != "bind" {
+ newOpts = append(newOpts, opt)
+ }
+ }
+ m.Options = newOpts
+ }
+}
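
changeMountType drops bind/rbind from the options when converting a mount to a non-bind type. A minimal stand-alone sketch of the same logic (the mount type here mirrors the specs.Mount fields it touches):

package main

import "fmt"

// mount mirrors the fields of specs.Mount that changeMountType touches.
type mount struct {
	Type    string
	Options []string
}

// retype mirrors changeMountType: per the OCI spec, bind mounts may be
// indicated purely via bind/rbind options, so those must be dropped when the
// mount is converted to a non-bind type such as tmpfs.
func retype(m *mount, newType string) {
	m.Type = newType
	if newType == "bind" {
		return
	}
	opts := m.Options[:0:0] // new backing array, avoids aliasing
	for _, o := range m.Options {
		if o != "bind" && o != "rbind" {
			opts = append(opts, o)
		}
	}
	m.Options = opts
}

func main() {
	m := &mount{Type: "bind", Options: []string{"ro", "bind", "rbind"}}
	retype(m, "tmpfs")
	fmt.Println(m.Type, m.Options) // tmpfs [ro]
}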
diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go
index 3e02c6151..b25c53c73 100644
--- a/pkg/shim/utils/volumes_test.go
+++ b/pkg/shim/utils/volumes_test.go
@@ -47,60 +47,60 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
}
for _, test := range []struct {
- desc string
+ name string
spec *specs.Spec
expected *specs.Spec
expectErr bool
expectUpdate bool
}{
{
- desc: "volume annotations for sandbox",
+ name: "volume annotations for sandbox",
spec: &specs.Spec{
Annotations: map[string]string{
- sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expected: &specs.Spec{
Annotations: map[string]string{
- sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
- "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
+ volumeKeyPrefix + testVolumeName + ".source": testVolumePath,
},
},
expectUpdate: true,
},
{
- desc: "volume annotations for sandbox with legacy log path",
+ name: "volume annotations for sandbox with legacy log path",
spec: &specs.Spec{
Annotations: map[string]string{
- sandboxLogDirAnnotation: testLegacyLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expected: &specs.Spec{
Annotations: map[string]string{
- sandboxLogDirAnnotation: testLegacyLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
- "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
+ volumeKeyPrefix + testVolumeName + ".source": testVolumePath,
},
},
expectUpdate: true,
},
{
- desc: "tmpfs: volume annotations for container",
+ name: "tmpfs: volume annotations for container",
spec: &specs.Spec{
Mounts: []specs.Mount{
{
@@ -117,10 +117,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expected: &specs.Spec{
@@ -139,16 +139,16 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expectUpdate: true,
},
{
- desc: "bind: volume annotations for container",
+ name: "bind: volume annotations for container",
spec: &specs.Spec{
Mounts: []specs.Mount{
{
@@ -159,10 +159,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "container",
+ volumeKeyPrefix + testVolumeName + ".type": "bind",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expected: &specs.Spec{
@@ -175,48 +175,48 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "container",
+ volumeKeyPrefix + testVolumeName + ".type": "bind",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expectUpdate: true,
},
{
- desc: "should not return error without pod log directory",
+ name: "should not return error without pod log directory",
spec: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
expected: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
- "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
- "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
},
},
},
{
- desc: "should return error if volume path does not exist",
+ name: "should return error if volume path does not exist",
spec: &specs.Spec{
Annotations: map[string]string{
- sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
- "dev.gvisor.spec.mount.notexist.share": "pod",
- "dev.gvisor.spec.mount.notexist.type": "tmpfs",
- "dev.gvisor.spec.mount.notexist.options": "ro",
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ volumeKeyPrefix + "notexist.share": "pod",
+ volumeKeyPrefix + "notexist.type": "tmpfs",
+ volumeKeyPrefix + "notexist.options": "ro",
},
},
expectErr: true,
},
{
- desc: "no volume annotations for sandbox",
+ name: "no volume annotations for sandbox",
spec: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
@@ -231,7 +231,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
{
- desc: "no volume annotations for container",
+ name: "no volume annotations for container",
spec: &specs.Spec{
Mounts: []specs.Mount{
{
@@ -271,8 +271,46 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
},
+ {
+ name: "bind options removed",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
+ volumeKeyPrefix + testVolumeName + ".source": testVolumePath,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/dst",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro", "bind", "rbind"},
+ },
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ volumeKeyPrefix + testVolumeName + ".share": "pod",
+ volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
+ volumeKeyPrefix + testVolumeName + ".options": "ro",
+ volumeKeyPrefix + testVolumeName + ".source": testVolumePath,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/dst",
+ Type: "tmpfs",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ },
+ expectUpdate: true,
+ },
} {
- t.Run(test.desc, func(t *testing.T) {
+ t.Run(test.name, func(t *testing.T) {
bundle, err := ioutil.TempDir(dir, "test-bundle")
if err != nil {
t.Fatalf("Create test bundle: %v", err)
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index d6c89c7e9..08d06e37b 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -7,7 +7,6 @@ go_library(
srcs = ["statefile.go"],
visibility = ["//:sandbox"],
deps = [
- "//pkg/binary",
"//pkg/compressio",
"//pkg/state/wire",
],
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
index bdfb800fb..d27c8c8a8 100644
--- a/pkg/state/statefile/statefile.go
+++ b/pkg/state/statefile/statefile.go
@@ -48,6 +48,7 @@ import (
"compress/flate"
"crypto/hmac"
"crypto/sha256"
+ "encoding/binary"
"encoding/json"
"fmt"
"hash"
@@ -55,7 +56,6 @@ import (
"strings"
"time"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/compressio"
"gvisor.dev/gvisor/pkg/state/wire"
)
@@ -90,6 +90,13 @@ type WriteCloser interface {
io.Closer
}
+func writeMetadataLen(w io.Writer, val uint64) error {
+ var buf [8]byte
+ binary.BigEndian.PutUint64(buf[:], val)
+ _, err := w.Write(buf[:])
+ return err
+}
+
// NewWriter returns a state data writer for a statefile.
//
// Note that the returned WriteCloser must be closed.
@@ -127,7 +134,7 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser
}
// Metadata length.
- if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+ if err := writeMetadataLen(mw, uint64(len(b))); err != nil {
return nil, err
}
// Metadata bytes; io.MultiWriter will return a short write error if
@@ -158,6 +165,14 @@ func MetadataUnsafe(r io.Reader) (map[string]string, error) {
return metadata(r, nil)
}
+func readMetadataLen(r io.Reader) (uint64, error) {
+ var buf [8]byte
+ if _, err := io.ReadFull(r, buf[:]); err != nil {
+ return 0, err
+ }
+ return binary.BigEndian.Uint64(buf[:]), nil
+}
+
// metadata validates the magic header and reads out the metadata from a state
// data stream.
func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
@@ -183,7 +198,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
}
}()
- metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+ metadataLen, err := readMetadataLen(r)
if err != nil {
return nil, err
}
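
The statefile change swaps the pkg/binary helpers for direct encoding/binary calls on a fixed 8-byte big-endian length prefix. A runnable round-trip of the same encoding (helper names mirror the ones added above):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// writeLen/readLen mirror the writeMetadataLen/readMetadataLen helpers:
// a fixed 8-byte big-endian length prefix.
func writeLen(w io.Writer, val uint64) error {
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], val)
	_, err := w.Write(buf[:])
	return err
}

func readLen(r io.Reader) (uint64, error) {
	var buf [8]byte
	if _, err := io.ReadFull(r, buf[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint64(buf[:]), nil
}

func main() {
	var b bytes.Buffer
	if err := writeLen(&b, 42); err != nil {
		panic(err)
	}
	n, err := readLen(&b)
	if err != nil {
		panic(err)
	}
	fmt.Println(n) // 42
}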
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index b2c5229e7..8b3a11c64 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -43,6 +43,7 @@ go_template(
],
deps = [
":sync",
+ "//pkg/gohacks",
],
)
diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go
index 82b676abf..9578c9c52 100644
--- a/pkg/sync/generic_seqatomic_unsafe.go
+++ b/pkg/sync/generic_seqatomic_unsafe.go
@@ -10,6 +10,7 @@ package seqatomic
import (
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -39,7 +40,7 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
// runtime.RaceDisable() doesn't actually stop the race detector, so it
// can't help us here. Instead, call runtime.memmove directly, which is
// not instrumented by the race detector.
- sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
// This is ~40% faster for short reads than going through memmove.
val = *ptr
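
The comment above explains the trick: a seqcount-protected read may race with a writer, so the copy goes through runtime.memmove (now via gohacks.Memmove) to stay invisible to the race detector, and the result is only kept if the epoch is unchanged. A simplified, self-contained seqlock reader showing the retry-until-consistent shape (this is not the gVisor sync API):

package main

import (
	"fmt"
	"sync/atomic"
)

// seqData pairs a sequence counter with racily-read data, like the
// SeqAtomicLoad/SeqAtomicTryLoad template instantiates.
type seqData struct {
	seq atomic.Uint32
	val [2]uint64
}

func (d *seqData) read() [2]uint64 {
	for {
		before := d.seq.Load()
		v := d.val // racy copy; gVisor routes this through gohacks.Memmove
		// Keep the copy only if no writer was active (even epoch) and the
		// epoch did not change across the copy.
		if before%2 == 0 && d.seq.Load() == before {
			return v
		}
	}
}

func (d *seqData) write(a, b uint64) {
	d.seq.Add(1) // odd: writer in progress
	d.val = [2]uint64{a, b}
	d.seq.Add(1) // even again
}

func main() {
	var d seqData
	d.write(1, 2)
	fmt.Println(d.read()) // [1 2]
}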
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
index 158985709..39c766331 100644
--- a/pkg/sync/runtime_unsafe.go
+++ b/pkg/sync/runtime_unsafe.go
@@ -17,20 +17,6 @@ import (
"unsafe"
)
-// Note that go:linkname silently doesn't work if the local name is exported,
-// necessitating an indirection for exported functions.
-
-// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
-//
-//go:nosplit
-func Memmove(to, from unsafe.Pointer, n uintptr) {
- memmove(to, from, n)
-}
-
-//go:linkname memmove runtime.memmove
-//go:noescape
-func memmove(to, from unsafe.Pointer, n uintptr)
-
// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock);
// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g)
// is called. unlockf and its callees must be nosplit and norace, since stack
diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
index 5c38c783e..5f9164117 100644
--- a/pkg/sync/seqatomictest/BUILD
+++ b/pkg/sync/seqatomictest/BUILD
@@ -18,6 +18,7 @@ go_library(
name = "seqatomic",
srcs = ["seqatomic_int_unsafe.go"],
deps = [
+ "//pkg/gohacks",
"//pkg/sync",
],
)
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index f979d22f0..e96ba50ae 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,4 +1,5 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:deps.bzl", "deps_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -21,8 +22,9 @@ go_library(
"errors.go",
"sock_err_list.go",
"socketops.go",
+ "stdclock.go",
+ "stdclock_state.go",
"tcpip.go",
- "time_unsafe.go",
"timer.go",
],
visibility = ["//visibility:public"],
@@ -33,6 +35,36 @@ go_library(
],
)
+deps_test(
+ name = "netstack_deps_test",
+ allowed = [
+ "@com_github_google_btree//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
+ ],
+ allowed_prefixes = [
+ "//",
+ "@org_golang_x_sys//internal/unsafeheader",
+ ],
+ targets = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/packetsocket",
+ "//pkg/tcpip/link/qdisc/fifo",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ ],
+)
+
go_test(
name = "tcpip_test",
size = "small",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index fef065b05..12c39dfa3 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
t.Error("Not a valid IPv4 packet")
}
- xsum := ipv4.CalculateChecksum()
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
+ if !ipv4.IsChecksumValid() {
+ t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
}
for _, f := range checkers {
@@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
- // Verify the checksum.
tcp := header.TCP(last.Payload())
- l := uint16(len(tcp))
-
- xsum := header.Checksum([]byte(first.SourceAddress()), 0)
- xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
- xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
- xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
- xsum = header.Checksum(tcp, xsum)
-
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+ payload := tcp.Payload()
+ payloadChecksum := header.Checksum(payload, 0)
+ if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
+ t.Errorf("Bad checksum, got = %d", tcp.Checksum())
}
// Run the transport checkers.
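
The checker now computes the payload checksum separately and hands it to TCP.IsChecksumValid, which folds it into the pseudo-header sum. This works because RFC 1071 one's-complement sums over even-length pieces combine additively; a small demonstration (sum16 is a simplified stand-in for header.Checksum):

package main

import "fmt"

// sum16 computes an RFC 1071 one's-complement sum (even-length input only).
func sum16(b []byte, init uint16) uint16 {
	s := uint32(init)
	for i := 0; i+1 < len(b); i += 2 {
		s += uint32(b[i])<<8 | uint32(b[i+1])
	}
	for s > 0xffff {
		s = (s >> 16) + (s & 0xffff) // fold end-around carry
	}
	return uint16(s)
}

func main() {
	hdr := []byte{0xde, 0xad, 0xbe, 0xef}
	payload := []byte{0x01, 0x02, 0x03, 0x04}
	// Checksumming the concatenation equals checksumming one piece with the
	// other piece's sum as the initial value, which is what lets the checker
	// pass a precomputed payloadChecksum into IsChecksumValid.
	whole := sum16(append(append([]byte{}, hdr...), payload...), 0)
	combined := sum16(hdr, sum16(payload, 0))
	fmt.Println(whole == combined) // true
}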
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
index 52c22230e..33ff22a7b 100644
--- a/pkg/tcpip/hash/jenkins/jenkins.go
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 }
// Sum32 returns the hash value
func (s *Sum32) Sum32() uint32 {
- hash := *s
+ sCopy := *s
- hash += (hash << 3)
- hash ^= hash >> 11
- hash += hash << 15
+ sCopy += sCopy << 3
+ sCopy ^= sCopy >> 11
+ sCopy += sCopy << 15
- return uint32(hash)
+ return uint32(sCopy)
}
// Write adds more data to the running hash.
//
// It never returns an error.
func (s *Sum32) Write(data []byte) (int, error) {
- hash := *s
+ sCopy := *s
for _, b := range data {
- hash += Sum32(b)
- hash += hash << 10
- hash ^= hash >> 6
+ sCopy += Sum32(b)
+ sCopy += sCopy << 10
+ sCopy ^= sCopy >> 6
}
- *s = hash
+ *s = sCopy
return len(data), nil
}
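
The rename from hash to sCopy only clarifies that the local is a working copy of the receiver; the Jenkins one-at-a-time algorithm is unchanged. A stand-alone re-implementation for reference (lower-case names, to avoid suggesting gVisor API):

package main

import "fmt"

// sum32 re-implements the Jenkins one-at-a-time hash from the diff so the
// renamed sCopy logic can be exercised in isolation.
type sum32 uint32

func (s *sum32) write(data []byte) {
	h := *s // work on a copy, as the Write method above does
	for _, b := range data {
		h += sum32(b)
		h += h << 10
		h ^= h >> 6
	}
	*s = h
}

func (s sum32) finish() uint32 {
	h := s // final mixing, as in Sum32 above
	h += h << 3
	h ^= h >> 11
	h += h << 15
	return uint32(h)
}

func main() {
	var h sum32
	h.write([]byte("gvisor"))
	fmt.Printf("%#08x\n", h.finish())
}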
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 0bdc12d53..01240f5d0 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -52,6 +52,7 @@ go_test(
"//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -69,6 +70,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 3bc8b2b21..bf9ccbf1a 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
func TestIsValidUnicastEthernetAddress(t *testing.T) {
@@ -142,7 +143,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
}
func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
- addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a")
if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
index b6126d29a..575604928 100644
--- a/pkg/tcpip/header/igmp_test.go
+++ b/pkg/tcpip/header/igmp_test.go
@@ -18,8 +18,8 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
// TestIGMPHeader tests the functions within header.igmp
@@ -46,7 +46,7 @@ func TestIGMPHeader(t *testing.T) {
t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
}
- if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want {
+ if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want {
t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
}
@@ -71,7 +71,7 @@ func TestIGMPHeader(t *testing.T) {
t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
}
- groupAddress := tcpip.Address("\x04\x03\x02\x01")
+ groupAddress := testutil.MustParse4("4.3.2.1")
igmpHeader.SetGroupAddress(groupAddress)
if got := igmpHeader.GroupAddress(); got != groupAddress {
t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index f588311e0..2be21ec75 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -178,6 +178,26 @@ const (
IPv4FlagDontFragment
)
+// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
+// by RFC 3927 section 1.
+var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
+// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
+// defined by RFC 5771 section 4.
+var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPv4EmptySubnet is the empty IPv4 subnet.
var IPv4EmptySubnet = func() tcpip.Subnet {
subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
@@ -423,6 +443,44 @@ func (b IPv4) IsValid(pktSize int) bool {
return true
}
+// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4
+// link-local unicast address.
+func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalUnicastSubnet.Contains(addr)
+}
+
+// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4
+// link-local multicast address.
+func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalMulticastSubnet.Contains(addr)
+}
+
+// IsChecksumValid returns true iff the IPv4 header's checksum is valid.
+func (b IPv4) IsChecksumValid() bool {
+	// There has been some confusion regarding verifying checksums. We need only
+	// look for negative 0 (0xffff) as the checksum, as it's not possible to get
+	// positive 0 (0) for the checksum. Some bad implementations could produce
+	// it when doing entry replacement in the early days of the Internet;
+	// however, the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ return b.CalculateChecksum() == 0xffff
+}
+
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
// will be 1110 = 0xe0.
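
IsChecksumValid's comment boils down to this: summing a buffer that includes a correct checksum field always yields 0xffff (-0 in one's complement), never 0. A small demonstration of that verification rule (onesComplementSum is a simplified stand-in for CalculateChecksum):

package main

import "fmt"

// onesComplementSum folds 16-bit words with end-around carry, per RFC 1071.
// Simplified sketch; assumes an even-length buffer.
func onesComplementSum(b []byte, initial uint16) uint16 {
	sum := uint32(initial)
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

func main() {
	data := []byte{0x01, 0x02, 0x03, 0x04}
	csum := ^onesComplementSum(data, 0) // transmitted checksum
	withField := append(data, byte(csum>>8), byte(csum))
	// Re-summing over the data plus its checksum field yields -0 (0xffff),
	// which is exactly the check IsChecksumValid performs.
	fmt.Printf("%#04x\n", onesComplementSum(withField, 0)) // 0xffff
}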
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
index 6475cd694..c02fe898b 100644
--- a/pkg/tcpip/header/ipv4_test.go
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) {
})
}
}
+
+func TestIsV4LinkLocalUnicastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xa9\xfe\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xa9\xfe\xff\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xa9\xfd\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xa9\xff\x00\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV4LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xe0\x00\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xe0\x00\x00\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xdf\xff\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xe0\x00\x01\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index f2403978c..c3a0407ac 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -98,12 +98,27 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- // IPv6AllRoutersMulticastAddress is a link-local multicast group that
- // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
+ // multicast group that all IPv6 routers MUST join, as per RFC 4291, section
+ // 2.8. Packets destined to this address will reach the router on an
+ // interface.
+ //
+ // The address is ff01::2.
+ IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all routers on a link.
//
// The address is ff02::2.
- IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers in a site.
+ //
+ // The address is ff05::2.
+ IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
// section 5:
@@ -142,11 +157,6 @@ const (
// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
// within the byte holding the field, as per RFC 4291 section 2.7.
ipv6MulticastAddressScopeMask = 0xF
-
- // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
- // a multicast IPv6 address that indicates the address has link-local scope,
- // as per RFC 4291 section 2.7.
- ipv6LinkLocalMulticastScope = 2
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
return tcpip.Address(lladdrb[:])
}
-// IsV6LinkLocalAddress determines if the provided address is an IPv6
-// link-local address (fe80::/10).
-func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
+// link-local unicast address, as defined by RFC 4291 section 2.5.6.
+func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
if len(addr) != IPv6AddressSize {
return false
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
-// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
-// address.
+// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
+// address, as defined by RFC 4291 section 2.5.3.
func IsV6LoopbackAddress(addr tcpip.Address) bool {
return addr == IPv6Loopback
}
-// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
-// link-local multicast address.
+// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
+// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
- return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+ return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
}
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
@@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
case IsV6LinkLocalMulticastAddress(addr):
return LinkLocalScope, nil
- case IsV6LinkLocalAddress(addr):
+ case IsV6LinkLocalUnicastAddress(addr):
return LinkLocalScope, nil
default:
@@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address)
PrefixLen: IIDOffsetInIPv6Address * 8,
}
}
+
+// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by
+// RFC 7346 section 2.
+type IPv6MulticastScope uint8
+
+// The various values for IPv6 multicast scopes, as per RFC 7346 section 2:
+//
+// +------+--------------------------+-------------------------+
+// | scop | NAME | REFERENCE |
+// +------+--------------------------+-------------------------+
+// | 0 | Reserved | [RFC4291], RFC 7346 |
+// | 1 | Interface-Local scope | [RFC4291], RFC 7346 |
+// | 2 | Link-Local scope | [RFC4291], RFC 7346 |
+// | 3 | Realm-Local scope | [RFC4291], RFC 7346 |
+// | 4 | Admin-Local scope | [RFC4291], RFC 7346 |
+// | 5 | Site-Local scope | [RFC4291], RFC 7346 |
+// | 6 | Unassigned | |
+// | 7 | Unassigned | |
+// | 8 | Organization-Local scope | [RFC4291], RFC 7346 |
+// | 9 | Unassigned | |
+// | A | Unassigned | |
+// | B | Unassigned | |
+// | C | Unassigned | |
+// | D | Unassigned | |
+// | E | Global scope | [RFC4291], RFC 7346 |
+// | F | Reserved | [RFC4291], RFC 7346 |
+// +------+--------------------------+-------------------------+
+const (
+ IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0)
+ IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1)
+ IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2)
+ IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3)
+ IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4)
+ IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5)
+ IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8)
+ IPv6GlobalMulticastScope = IPv6MulticastScope(0xE)
+ IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF)
+)
+
+// V6MulticastScope returns the scope of a multicast address.
+func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
+ return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask)
+}
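
V6MulticastScope just masks the low nibble of the address's second byte. A tiny stand-alone illustration (scopeOf is hypothetical, mirroring the implementation above):

package main

import "fmt"

// scopeOf mirrors header.V6MulticastScope: the scope is the low nibble of the
// second byte of a multicast (ff00::/8) address, per RFC 4291 section 2.7.
func scopeOf(addr [16]byte) uint8 {
	return addr[1] & 0xF
}

func main() {
	allNodes := [16]byte{0xff, 0x02, 15: 0x01}   // ff02::1, link-local scope
	allRouters := [16]byte{0xff, 0x05, 15: 0x02} // ff05::2, site-local scope
	fmt.Println(scopeOf(allNodes), scopeOf(allRouters)) // 2 5
}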
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index f10f446a6..89be84068 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -24,15 +24,17 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+var (
+ linkLocalAddr = testutil.MustParse6("fe80::1")
+ linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
+ uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
+ uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
+ globalAddr = testutil.MustParse6("a000::1")
)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
@@ -50,7 +52,7 @@ func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
}
func TestLinkLocalAddr(t *testing.T) {
- if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
+ if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want {
t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
}
}
@@ -252,7 +254,7 @@ func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
}
}
-func TestIsV6LinkLocalAddress(t *testing.T) {
+func TestIsV6LinkLocalUnicastAddress(t *testing.T) {
tests := []struct {
name string
addr tcpip.Address
@@ -287,8 +289,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
}
})
}
@@ -373,3 +375,83 @@ func TestSolicitedNodeAddr(t *testing.T) {
})
}
}
+
+func TestV6MulticastScope(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want header.IPv6MulticastScope
+ }{
+ {
+ addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6Reserved0MulticastScope,
+ },
+ {
+ addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6InterfaceLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6LinkLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6RealmLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6AdminLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6SiteLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(6),
+ },
+ {
+ addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(7),
+ },
+ {
+ addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6OrganizationLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(9),
+ },
+ {
+ addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(10),
+ },
+ {
+ addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(11),
+ },
+ {
+ addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(12),
+ },
+ {
+ addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(13),
+ },
+ {
+ addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6GlobalMulticastScope,
+ },
+ {
+ addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6ReservedFMulticastScope,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.V6MulticastScope(test.addr); got != test.want {
+ t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index d0a1a2492..1b5093e58 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
@@ -40,13 +41,13 @@ func TestNDPNeighborSolicit(t *testing.T) {
// Test getting the Target Address.
ns := NDPNeighborSolicit(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
if got := ns.TargetAddress(); got != addr {
t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
}
// Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
ns.SetTargetAddress(addr2)
if got := ns.TargetAddress(); got != addr2 {
t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
@@ -69,7 +70,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
// Test getting the Target Address.
na := NDPNeighborAdvert(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
if got := na.TargetAddress(); got != addr {
t.Errorf("got TargetAddress = %s, want %s", got, addr)
}
@@ -90,7 +91,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
// Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
na.SetTargetAddress(addr2)
if got := na.TargetAddress(); got != addr2 {
t.Errorf("got TargetAddress = %s, want %s", got, addr2)
@@ -277,7 +278,7 @@ func TestOpts(t *testing.T) {
}
const validLifetimeSeconds = 16909060
- const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18")
+ address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718")
expectedRDNSSBytes := [...]byte{
// Type, Length
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index adc835d30..0df517000 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -216,104 +216,104 @@ const (
TCPDefaultMSS = 536
)
-// SourcePort returns the "source port" field of the tcp header.
+// SourcePort returns the "source port" field of the TCP header.
func (b TCP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
}
-// DestinationPort returns the "destination port" field of the tcp header.
+// DestinationPort returns the "destination port" field of the TCP header.
func (b TCP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
}
-// SequenceNumber returns the "sequence number" field of the tcp header.
+// SequenceNumber returns the "sequence number" field of the TCP header.
func (b TCP) SequenceNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
}
-// AckNumber returns the "ack number" field of the tcp header.
+// AckNumber returns the "ack number" field of the TCP header.
func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
-// DataOffset returns the "data offset" field of the tcp header. The return
+// DataOffset returns the "data offset" field of the TCP header. The return
// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
-// Payload returns the data in the tcp packet.
+// Payload returns the data in the TCP packet.
func (b TCP) Payload() []byte {
return b[b.DataOffset():]
}
-// Flags returns the flags field of the tcp header.
+// Flags returns the flags field of the TCP header.
func (b TCP) Flags() TCPFlags {
return TCPFlags(b[TCPFlagsOffset])
}
-// WindowSize returns the "window size" field of the tcp header.
+// WindowSize returns the "window size" field of the TCP header.
func (b TCP) WindowSize() uint16 {
return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
}
-// Checksum returns the "checksum" field of the tcp header.
+// Checksum returns the "checksum" field of the TCP header.
func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
-// UrgentPointer returns the "urgent pointer" field of the tcp header.
+// UrgentPointer returns the "urgent pointer" field of the TCP header.
func (b TCP) UrgentPointer() uint16 {
return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
}
-// SetSourcePort sets the "source port" field of the tcp header.
+// SetSourcePort sets the "source port" field of the TCP header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
}
-// SetDestinationPort sets the "destination port" field of the tcp header.
+// SetDestinationPort sets the "destination port" field of the TCP header.
func (b TCP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
}
-// SetChecksum sets the checksum field of the tcp header.
+// SetChecksum sets the checksum field of the TCP header.
func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
}
-// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// SetDataOffset sets the data offset field of the TCP header. headerLen should
// be the length of the TCP header in bytes.
func (b TCP) SetDataOffset(headerLen uint8) {
b[TCPDataOffset] = (headerLen / 4) << 4
}
-// SetSequenceNumber sets the sequence number field of the tcp header.
+// SetSequenceNumber sets the sequence number field of the TCP header.
func (b TCP) SetSequenceNumber(seqNum uint32) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
}
-// SetAckNumber sets the ack number field of the tcp header.
+// SetAckNumber sets the ack number field of the TCP header.
func (b TCP) SetAckNumber(ackNum uint32) {
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
}
-// SetFlags sets the flags field of the tcp header.
+// SetFlags sets the flags field of the TCP header.
func (b TCP) SetFlags(flags uint8) {
b[TCPFlagsOffset] = flags
}
-// SetWindowSize sets the window size field of the tcp header.
+// SetWindowSize sets the window size field of the TCP header.
func (b TCP) SetWindowSize(rcvwnd uint16) {
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
-// SetUrgentPoiner sets the window size field of the tcp header.
+// SetUrgentPoiner sets the "urgent pointer" field of the TCP header.
func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
}
-// CalculateChecksum calculates the checksum of the tcp segment.
+// CalculateChecksum calculates the checksum of the TCP segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
@@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
return Checksum(b[:b.DataOffset()], partialChecksum)
}
+// IsChecksumValid returns true iff the TCP header's checksum is valid.
+func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
+ xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength)
+ xsum = ChecksumCombine(xsum, payloadChecksum)
+ return b.CalculateChecksum(xsum) == 0xffff
+}
+
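A minimal caller-side sketch of the new helper, assuming hdr is a header.TCP view over the whole segment and src/dst come from the enclosing IP header (names here are illustrative):

    payload := hdr.Payload()
    // Checksum the payload alone; IsChecksumValid folds in the pseudo-header
    // and the TCP header itself, and a valid result sums to 0xffff.
    payloadXsum := header.Checksum(payload, 0)
    if !hdr.IsChecksumValid(src, dst, payloadXsum, uint16(len(payload))) {
        // Invalid checksum: drop the segment.
    }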
// Options returns a slice that holds the unparsed TCP options in the segment.
func (b TCP) Options() []byte {
return b[TCPMinimumSize:b.DataOffset()]
@@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
-// Encode encodes all the fields of the tcp header.
+// Encode encodes all the fields of the TCP header.
func (b TCP) Encode(t *TCPFields) {
b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort)
@@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) {
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer)
}
-// EncodePartial updates a subset of the fields of the tcp header. It is useful
+// EncodePartial updates a subset of the fields of the TCP header. It is useful
// in cases when similar segments are produced.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
@@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
}
// ParseSynOptions parses the options received in a SYN segment and returns the
-// relevant ones. opts should point to the option part of the TCP Header.
+// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
limit := len(opts)
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 98bdd29db..ae9d167ff 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -64,17 +64,17 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
+// SourcePort returns the "source port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "destination port" field of the udp header.
+// DestinationPort returns the "destination port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "length" field of the udp header.
+// Length returns the "length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
@@ -84,39 +84,46 @@ func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the udp header.
+// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the udp header.
+// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the udp header.
+// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the udp header.
+// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the udp header.
+// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the
+// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return Checksum(b[:UDPMinimumSize], partialChecksum)
}
-// Encode encodes all the fields of the udp header.
+// IsChecksumValid returns true iff the UDP header's checksum is valid.
+func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
+ xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length())
+ xsum = ChecksumCombine(xsum, payloadChecksum)
+ return b.CalculateChecksum(xsum) == 0xffff
+}
+
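The UDP variant mirrors the TCP one; passing (dst, src) rather than (src, dst) to PseudoHeaderChecksum is harmless because the ones'-complement sum is commutative. A hedged usage sketch:

    u := header.UDP(datagram) // datagram holds the UDP header plus payload
    payloadXsum := header.Checksum(u.Payload(), 0)
    // Over IPv4 a transmitted checksum of zero means "not computed", so only
    // a nonzero, invalid checksum should cause the datagram to be dropped.
    if u.Checksum() != 0 && !u.IsChecksumValid(srcAddr, dstAddr, payloadXsum) {
        // Drop the datagram.
    }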
+// Encode encodes all the fields of the UDP header.
func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index cd76272de..ef9126deb 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -30,7 +30,6 @@ import (
type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
- GSO *stack.GSO
Route stack.RouteInfo
}
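PacketInfo drops its GSO field because segmentation-offload parameters now travel with the packet itself. Code that previously read pktInfo.GSO would now consult the packet buffer, roughly:

    opts := pktInfo.Pkt.GSOOptions // a stack.GSO value; Type is stack.GSONone when unset
    if opts.Type != stack.GSONone {
        _ = opts.MSS // offload parameters ride along with the packet
    }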
@@ -124,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) {
q.notify = notify
}
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+
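These var declarations are compile-time conformance checks: assigning a typed nil pointer to a blank interface-typed variable costs nothing at runtime but breaks the build if *Endpoint ever stops satisfying either interface. The same pattern in isolation:

    type greeter interface{ Greet() string }

    type impl struct{}

    func (impl) Greet() string { return "hi" }

    // Fails to compile if impl ever loses the Greet method.
    var _ greeter = impl{}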
// Endpoint is a link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -131,6 +133,7 @@ type Endpoint struct {
mtu uint32
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
+ SupportedGSOKind stack.SupportedGSO
// Outbound packet queue.
q *queue
@@ -212,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.LinkEPCapabilities
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (*Endpoint) GSOMaxSize() uint32 {
return 1 << 15
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.SupportedGSOKind
+}
+
// MaxHeaderLength returns the maximum size of the link layer header. Since
// this endpoint does not add a header, it returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
@@ -229,11 +237,10 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
- GSO: gso,
Route: r,
}
@@ -243,13 +250,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
- GSO: gso,
Route: r,
}
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index d873766a6..b427c6170 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -61,20 +61,20 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
- return e.Endpoint.WritePacket(r, gso, proto, pkt)
+ return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
}
- return e.Endpoint.WritePackets(r, gso, pkts, proto)
+ return e.Endpoint.WritePackets(r, pkts, proto)
}
// MaxHeaderLength implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index f042df82e..d971194e6 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/binary",
"//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 6be945116..bddb1d0a2 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,7 +45,6 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string {
}
}
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -134,6 +136,9 @@ type endpoint struct {
// wg keeps track of running goroutines.
wg sync.WaitGroup
+
+ // gsoKind is the supported kind of GSO.
+ gsoKind stack.SupportedGSO
}
// Options specify the details about the fd-based endpoint to be created.
@@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
if isSocket {
if opts.GSOMaxSize != 0 {
if opts.SoftwareGSOEnabled {
- e.caps |= stack.CapabilitySoftwareGSO
+ e.gsoKind = stack.SWGSOSupported
} else {
- e.caps |= stack.CapabilityHardwareGSO
+ e.gsoKind = stack.HWGSOSupported
}
e.gsoMaxSize = opts.GSOMaxSize
}
@@ -403,6 +408,35 @@ type virtioNetHdr struct {
csumOffset uint16
}
+// marshal serializes h to a newly-allocated byte slice, in little-endian byte
+// order.
+//
+// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used
+// for general serialization. This makes it difficult to use go-marshal for
+// virtio types, as go-marshal implicitly uses the native byte ordering.
+func (h *virtioNetHdr) marshal() []byte {
+ buf := [virtioNetHdrSize]byte{
+ 0: byte(h.flags),
+ 1: byte(h.gsoType),
+
+ // Manually lay out the fields in little-endian byte order. Little endian =>
+ // least significant byte goes to the lower address.
+
+ 2: byte(h.hdrLen),
+ 3: byte(h.hdrLen >> 8),
+
+ 4: byte(h.gsoSize),
+ 5: byte(h.gsoSize >> 8),
+
+ 6: byte(h.csumStart),
+ 7: byte(h.csumStart >> 8),
+
+ 8: byte(h.csumOffset),
+ 9: byte(h.csumOffset >> 8),
+ }
+ return buf[:]
+}
+
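For comparison, a hypothetical equivalent using encoding/binary; it produces the same ten bytes and makes the layout explicit (marshalAlt is illustrative only, and assumes flags and gsoType are single-byte fields, as the byte() conversions above suggest):

    func (h *virtioNetHdr) marshalAlt() []byte {
        buf := make([]byte, virtioNetHdrSize)
        buf[0] = byte(h.flags)
        buf[1] = byte(h.gsoType)
        binary.LittleEndian.PutUint16(buf[2:], h.hdrLen)
        binary.LittleEndian.PutUint16(buf[4:], h.gsoSize)
        binary.LittleEndian.PutUint16(buf[6:], h.csumStart)
        binary.LittleEndian.PutUint16(buf[8:], h.csumOffset)
        return buf
    }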
// These constants are declared in linux/virtio_net.h.
const (
_VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
@@ -433,7 +467,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if e.hdrSize > 0 {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
@@ -441,29 +475,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
- if gso != nil {
+ if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
- if gso.NeedsCsum {
+ if pkt.GSOOptions.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
- vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
- vnetHdr.csumOffset = gso.CsumOffset
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
}
- if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS {
- switch gso.Type {
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
case stack.GSOTCPv4:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
case stack.GSOTCPv6:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
default:
- panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
}
- vnetHdr.gsoSize = gso.MSS
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
- vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf := vnetHdr.marshal()
builder.Add(vnetHdrBuf)
}
@@ -482,9 +516,9 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
}
var vnetHdrBuf []byte
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
- if pkt.GSOOptions != nil {
+ if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if pkt.GSOOptions.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
@@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
- vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf = vnetHdr.marshal()
}
var builder iovec.Builder
@@ -540,7 +574,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
// - pkt.EgressRoute
// - pkt.GSOOptions
// - pkt.NetworkProtocolNumber
-func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
// Preallocate to avoid repeated reallocation as we append to batch.
// batchSz is 47 because, when SWGSO is in use, a single 65KB TCP
// segment can get split into 46 segments of 1420 bytes and a single 216
@@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
}
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.gsoKind
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
if e.hdrSize > 0 {
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 1e40f3fef..8aad338b6 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -207,18 +207,17 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
// Write.
want := append(append(buffer.View(nil), b...), payload...)
- var gso *stack.GSO
+ const l3HdrLen = header.IPv6MinimumSize
if gsoMaxSize != 0 {
- gso = &stack.GSO{
+ pkt.GSOOptions = stack.GSO{
Type: stack.GSOTCPv6,
NeedsCsum: true,
CsumOffset: csumOffset,
MSS: gsoMSS,
- MaxSize: gsoMaxSize,
- L3HdrLen: header.IPv4MaximumHeaderSize,
+ L3HdrLen: l3HdrLen,
}
}
- if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -235,7 +234,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 {
t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM)
}
- csumStart := header.EthernetMinimumSize + gso.L3HdrLen
+ const csumStart = header.EthernetMinimumSize + l3HdrLen
if vnetHdr.csumStart != csumStart {
t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart)
}
@@ -243,7 +242,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset)
}
gsoType := uint8(0)
- if int(gso.MSS) < plen {
+ if plen > gsoMSS {
gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
}
if vnetHdr.gsoType != gsoType {
@@ -333,7 +332,7 @@ func TestPreserveSrcAddress(t *testing.T) {
ReserveHeaderBytes: header.EthernetMinimumSize,
Data: buffer.VectorisedView{},
})
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index a7adf822b..4b7ef3aac 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -128,7 +128,7 @@ type readVDispatcher struct {
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
}
@@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
for i := range d.bufs {
d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 691467870..7012d8829 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// Construct data as the unparsed portion for the loopback packet.
data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 668f72eee..3e2a1aa94 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -87,20 +87,20 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber,
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
return 0, &tcpip.ErrNoRoute{}
}
- return endpoint.WritePackets(r, gso, pkts, protocol)
+ return endpoint.WritePackets(r, pkts, protocol)
}
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
- return endpoint.WritePacket(r, gso, protocol, pkt)
+ return endpoint.WritePacket(r, protocol, pkt)
}
return &tcpip.ErrNoRoute{}
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 5806f7fdf..040e3a35b 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -54,7 +54,7 @@ func TestInjectableEndpointDispatch(t *testing.T) {
var packetRoute stack.RouteInfo
packetRoute.RemoteAddress = dstIP
- endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
+ endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
@@ -76,7 +76,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
pkt.TransportHeader().Push(1)[0] = 0xFA
var packetRoute stack.RouteInfo
packetRoute.RemoteAddress = dstIP
- endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
+ endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
if err != nil {
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 97ad9fdd5..3e816b0c7 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -113,13 +113,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return e.child.WritePacket(r, gso, protocol, pkt)
+func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WritePacket(r, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return e.child.WritePackets(r, gso, pkts, protocol)
+func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return e.child.WritePackets(r, pkts, protocol)
}
// Wait implements stack.LinkEndpoint.
@@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
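Wrappers like nested.Endpoint surface optional capabilities by probing the endpoint they wrap at runtime. Condensed, the pattern is:

    func supportedGSO(child stack.LinkEndpoint) stack.SupportedGSO {
        if gep, ok := child.(stack.GSOEndpoint); ok {
            return gep.SupportedGSO() // delegate when the child has the capability
        }
        return stack.GSONotSupported // safe default otherwise
    }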
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index 6cbe18a56..e01837e2d 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -35,16 +35,16 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
- return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+ return e.Endpoint.WritePacket(r, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
- return e.Endpoint.WritePackets(r, gso, pkts, proto)
+ return e.Endpoint.WritePackets(r, pkts, proto)
}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 21fb87757..5030b6ba1 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -66,7 +66,7 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var pkts stack.PacketBufferList
pkts.PushBack(pkt)
e.deliverPackets(r, proto, pkts)
@@ -74,7 +74,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := pkts.Len()
e.deliverPackets(r, proto, pkts)
return n, nil
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 128ef6e87..b1a28491d 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -25,6 +25,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
// endpoint represents a LinkEndpoint which implements a FIFO queue for all
// outgoing packets. endpoint can have one or more underlying queueDispatchers.
// All outgoing packets are consistently hashed to a single underlying queue
@@ -91,7 +94,7 @@ func (q *queueDispatcher) dispatchLoop() {
}
// We pass a protocol of zero here because each packet carries its
// NetworkProtocol.
- q.lower.WritePackets(stack.RouteInfo{}, nil /* gso */, batch, 0 /* protocol */)
+ q.lower.WritePackets(stack.RouteInfo{}, batch, 0 /* protocol */)
for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() {
batch.Remove(pkt)
}
@@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.lower.LinkAddress()
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
if gso, ok := e.lower.(stack.GSOEndpoint); ok {
return gso.GSOMaxSize()
@@ -149,13 +152,21 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- // WritePacket caller's do not set the following fields in PacketBuffer
- // so we populate them here.
- pkt.EgressRoute = r
- pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = protocol
+//
+// The packet must have the following fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
if !d.q.enqueue(pkt) {
return &tcpip.ErrNoBufferSpace{}
@@ -166,12 +177,12 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements stack.LinkEndpoint.WritePackets.
//
-// Being a batch API, each packet in pkts should have the following
-// fields populated:
+// Each packet in the packet buffer list must have the following fields
+// populated:
// - pkt.EgressRoute
// - pkt.GSOOptions
// - pkt.NetworkProtocolNumber
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
enqueued := 0
for pkt := pkts.Front(); pkt != nil; {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
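Since the qdisc endpoint no longer copies the route, GSO options, and protocol into each packet, callers must populate those fields before writing. A sketch of the new calling convention (variable names are illustrative):

    pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
        ReserveHeaderBytes: int(ep.MaxHeaderLength()),
        Data:               payload.ToVectorisedView(),
    })
    pkt.EgressRoute = routeInfo
    pkt.GSOOptions = gsoOpts // a stack.GSO value; Type stack.GSONone when unused
    pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
    if err := ep.WritePacket(routeInfo, header.IPv4ProtocolNumber, pkt); err != nil {
        // &tcpip.ErrNoBufferSpace{} is returned when the chosen queue is full.
    }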
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index d8d0b16b2..df9a0b90a 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
views := pkt.Views()
@@ -220,7 +220,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d4b3ddd5c..0f72d4e95 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -281,7 +281,7 @@ func TestSimpleSend(t *testing.T) {
copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -351,7 +351,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
})
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -407,7 +407,7 @@ func TestFillTxQueue(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -425,7 +425,7 @@ func TestFillTxQueue(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
@@ -453,7 +453,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -476,7 +476,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -494,7 +494,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
@@ -520,7 +520,7 @@ func TestFillTxMemory(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -539,7 +539,7 @@ func TestFillTxMemory(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
@@ -566,7 +566,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -581,7 +581,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buffer.NewView(bufferSize).ToVectorisedView(),
})
- err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
@@ -593,7 +593,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
+ if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 7aaee3d13..2d6a3a833 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -139,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket(directionRecv, nil, protocol, pkt)
+ e.dumpPacket(directionRecv, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -148,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(e.logPrefix, dir, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -187,22 +187,22 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.dumpPacket(directionSend, gso, protocol, pkt)
- return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.dumpPacket(directionSend, protocol, pkt)
+ return e.Endpoint.WritePacket(r, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket(directionSend, gso, protocol, pkt)
+ e.dumpPacket(directionSend, protocol, pkt)
}
- return e.Endpoint.WritePackets(r, gso, pkts, protocol)
+ return e.Endpoint.WritePackets(r, pkts, protocol)
}
-func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -411,8 +411,8 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
return
}
- if gso != nil {
- details += fmt.Sprintf(" gso: %+v", gso)
+ if pkt.GSOOptions.Type != stack.GSONone {
+ details += fmt.Sprintf(" gso: %#v", pkt.GSOOptions)
}
log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index ce5113746..a95602aa5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -108,12 +108,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WritePacket(r, gso, protocol, pkt)
+ err := e.lower.WritePacket(r, protocol, pkt)
e.writeGate.Leave()
return err
}
@@ -121,12 +121,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
if !e.writeGate.Enter() {
return pkts.Len(), nil
}
- n, err := e.lower.WritePackets(r, gso, pkts, protocol)
+ n, err := e.lower.WritePackets(r, pkts, protocol)
e.writeGate.Leave()
return n, err
}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index e368a9eaa..a71400ee9 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+func (e *countedEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
e.writeCount += pkts.Len()
return pkts.Len(), nil
}
@@ -98,21 +98,21 @@ func TestWaitWrite(t *testing.T) {
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
+ wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
+ wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
+ wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index fa8814bac..7b1ff44f4 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -21,6 +21,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d59d678b2..6905b9ccb 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -33,6 +33,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7ae38d684..0efa3a926 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -136,7 +136,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (*endpoint) Close() {}
-func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
+func (*endpoint) WritePacket(*stack.Route, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -146,7 +146,7 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (*endpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
return 0, &tcpip.ErrNotSupported{}
}
@@ -222,7 +222,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
//
// Send the packet to the (new) target hardware address on the same
// hardware on which the request was received.
- if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil {
+ if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), ProtocolNumber, respPkt); err != nil {
stats.outgoingRepliesDropped.Increment()
} else {
stats.outgoingRepliesSent.Increment()
@@ -355,7 +355,7 @@ func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLin
}
stats := e.stats.arp
- if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 018d6a578..94209b026 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -30,20 +30,16 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
const (
nicID = 1
- stackAddr = tcpip.Address("\x0a\x00\x00\x01")
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
-
- remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
- unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
-
defaultChannelSize = 1
defaultMTU = 65536
@@ -54,6 +50,12 @@ const (
eventChanSize = 32
)
+var (
+ stackAddr = testutil.MustParse4("10.0.0.1")
+ remoteAddr = testutil.MustParse4("10.0.0.2")
+ unknownAddr = testutil.MustParse4("10.0.0.3")
+)
+
type eventType uint8
const (
@@ -449,12 +451,12 @@ type testLinkEndpoint struct {
writeErr tcpip.Error
}
-func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
- return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
+ return t.LinkEndpoint.WritePacket(r, protocol, pkt)
}
func TestLinkAddressRequest(t *testing.T) {
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
index d21b4c7ef..fd944ce99 100644
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ b/pkg/tcpip/network/internal/ip/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ip",
srcs = [
"duplicate_address_detection.go",
+ "errors.go",
"generic_multicast_protocol.go",
"stats.go",
],
diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go
new file mode 100644
index 000000000..50fabfd79
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/errors.go
@@ -0,0 +1,77 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// ForwardingError represents an error that occurred while trying to forward
+// a packet.
+type ForwardingError interface {
+ isForwardingError()
+ fmt.Stringer
+}
+
+// ErrTTLExceeded indicates that the received packet's TTL has been exceeded.
+type ErrTTLExceeded struct{}
+
+func (*ErrTTLExceeded) isForwardingError() {}
+
+func (*ErrTTLExceeded) String() string { return "ttl exceeded" }
+
+// ErrIPOptProblem indicates the received packet had a problem with an IP
+// option.
+type ErrIPOptProblem struct{}
+
+func (*ErrIPOptProblem) isForwardingError() {}
+
+func (*ErrIPOptProblem) String() string { return "ip option problem" }
+
+// ErrLinkLocalSourceAddress indicates the received packet had a link-local
+// source address.
+type ErrLinkLocalSourceAddress struct{}
+
+func (*ErrLinkLocalSourceAddress) isForwardingError() {}
+
+func (*ErrLinkLocalSourceAddress) String() string { return "link local source address" }
+
+// ErrLinkLocalDestinationAddress indicates the received packet had a link-local
+// destination address.
+type ErrLinkLocalDestinationAddress struct{}
+
+func (*ErrLinkLocalDestinationAddress) isForwardingError() {}
+
+func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" }
+
+// ErrNoRoute indicates that the netstack couldn't find a route for the
+// received packet.
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isForwardingError() {}
+
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrOther indicates the packet could not be forwarded for a reason
+// captured by the contained error.
+type ErrOther struct {
+ Err tcpip.Error
+}
+
+func (*ErrOther) isForwardingError() {}
+
+func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) }
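A sketch of how forwarding code might fan these errors out into per-reason counters; the stat and logging names are hypothetical:

    switch fwdErr := err.(type) {
    case nil:
        // Forwarded successfully.
    case *ip.ErrTTLExceeded:
        stats.ForwardingTTLExceeded.Increment()
    case *ip.ErrLinkLocalSourceAddress, *ip.ErrLinkLocalDestinationAddress:
        stats.ForwardingLinkLocalAddress.Increment()
    case *ip.ErrNoRoute:
        stats.ForwardingNoRoute.Increment()
    case *ip.ErrOther:
        log.Warningf("forwarding failed: %s", fwdErr.Err)
    }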
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index b9f129728..d22974b12 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ip holds IPv4/IPv6 common utilities.
package ip
import (
@@ -156,14 +155,6 @@ type GenericMulticastProtocolOptions struct {
//
// Unsolicited reports are transmitted when a group is newly joined.
MaxUnsolicitedReportDelay time.Duration
-
- // AllNodesAddress is a multicast address that all nodes on a network should
- // be a member of.
- //
- // This address will not have the generic multicast protocol performed on it;
- // it will be left in the non member/listener state, and packets will never
- // be sent for it.
- AllNodesAddress tcpip.Address
}
// MulticastGroupProtocol is a multicast group protocol whose core state machine
@@ -188,6 +179,10 @@ type MulticastGroupProtocol interface {
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) tcpip.Error
+
+ // ShouldPerformProtocol returns true iff the protocol should be performed for
+ // the specified group.
+ ShouldPerformProtocol(tcpip.Address) bool
}
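With the special-casing moved out of the generic state machine (see the removed RFC 2236/2710 comments below), each protocol now encodes its own rule. For IGMP the all-systems case might look like this (a sketch, not the exact gVisor implementation):

    // ShouldPerformProtocol implements ip.MulticastGroupProtocol.
    //
    // Per RFC 2236 section 6, the all-systems group (224.0.0.1) stays in
    // Idle Member state and never has reports sent for it.
    func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
        return groupAddress != header.IPv4AllSystems
    }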
// GenericMulticastProtocolState is the per-interface generic multicast protocol
@@ -455,20 +450,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
info.lastToSendReport = false
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
info.state = idleMember
return
}
@@ -537,20 +519,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
@@ -627,20 +596,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
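The RFC 2710 special case quoted in the removed comments moves behind the same hook on the MLD side, which is not part of this hunk. A minimal sketch of that counterpart, assuming the MLD state keeps the upstream naming:

// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
//
// Sketch only: per RFC 2710 section 5, MLD is never performed for the
// link-scope all-nodes address (FF02::1).
func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
	return groupAddress != header.IPv6AllNodesMulticastAddress
}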
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 381460c82..0b51563cd 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct {
type mockMulticastGroupProtocol struct {
t *testing.T
+ skipProtocolAddress tcpip.Address
+
mu mockMulticastGroupProtocolProtectedFields
}
@@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip
return nil
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ return groupAddress != m.skipProtocolAddress
+}
+
func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
m.mu.Lock()
defer m.mu.Unlock()
@@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
cmp.FilterPath(
func(p cmp.Path) bool {
switch p.Last().String() {
- case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
return true
+ default:
+ return false
}
- return false
},
cmp.Ignore(),
),
@@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
// Joining a group should send a report immediately and another after
@@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
mgp.joinGroup(test.addr)
@@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) {
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index b6f39ddb1..392f0b0c7 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -18,70 +18,114 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// LINT.IfChange(MultiCounterIPStats)
+// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter
+// may have several versions.
+type MultiCounterIPForwardingStats struct {
+ // Unrouteable is the number of IP packets received which were dropped
+ // because the netstack could not construct a route to their
+ // destination.
+ Unrouteable tcpip.MultiCounterStat
+
+ // ExhaustedTTL is the number of IP packets received which were dropped
+ // because their TTL was exhausted.
+ ExhaustedTTL tcpip.MultiCounterStat
+
+ // LinkLocalSource is the number of IP packets which were dropped
+ // because they contained a link-local source address.
+ LinkLocalSource tcpip.MultiCounterStat
+
+ // LinkLocalDestination is the number of IP packets which were dropped
+ // because they contained a link-local destination address.
+ LinkLocalDestination tcpip.MultiCounterStat
+
+ // Errors is the number of IP packets received which could not be
+ // successfully forwarded.
+ Errors tcpip.MultiCounterStat
+}
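Each field mirrors a counter in tcpip.IPForwardingStats. A MultiCounterStat is assumed to fan increments out to the two counters passed to Init, roughly like this sketch (the real type lives in pkg/tcpip and may differ in detail):

type multiCounterStat struct {
	a, b *tcpip.StatCounter
}

// Init ties the multi-counter to two underlying counters.
func (m *multiCounterStat) Init(a, b *tcpip.StatCounter) { m.a, m.b = a, b }

// Increment bumps both tracked counters so per-endpoint and stack-wide
// stats stay in sync.
func (m *multiCounterStat) Increment() {
	m.a.Increment()
	m.b.Increment()
}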
+
// MultiCounterIPStats holds IP statistics; each counter may have several
// versions.
type MultiCounterIPStats struct {
- // PacketsReceived is the total number of IP packets received from the link
+ // PacketsReceived is the number of IP packets received from the link
// layer.
PacketsReceived tcpip.MultiCounterStat
- // DisabledPacketsReceived is the total number of IP packets received from the
- // link layer when the IP layer is disabled.
+ // DisabledPacketsReceived is the number of IP packets received from
+ // the link layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
- // InvalidDestinationAddressesReceived is the total number of IP packets
+ // InvalidDestinationAddressesReceived is the number of IP packets
// received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
- // InvalidSourceAddressesReceived is the total number of IP packets received
- // with a source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received
+ // with a source address that should never have been received on the
+ // wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
- // PacketsDelivered is the total number of incoming IP packets that are
+ // PacketsDelivered is the number of incoming IP packets that are
// successfully delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
- // PacketsSent is the total number of IP packets sent via WritePacket.
+ // PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
- // OutgoingPacketErrors is the total number of IP packets which failed to
+ // OutgoingPacketErrors is the number of IP packets which failed to
// write to a link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
- // MalformedPacketsReceived is the total number of IP Packets that were
+ // MalformedPacketsReceived is the number of IP Packets that were
// dropped due to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
- // MalformedFragmentsReceived is the total number of IP Fragments that were
+ // MalformedFragmentsReceived is the number of IP Fragments that were
// dropped due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
- // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
- // IPTablesInputDropped is the total number of IP packets dropped in the Input
- // chain.
+ // IPTablesInputDropped is the number of IP packets dropped in the
+ // Input chain.
IPTablesInputDropped tcpip.MultiCounterStat
- // IPTablesOutputDropped is the total number of IP packets dropped in the
+ // IPTablesOutputDropped is the number of IP packets dropped in the
// Output chain.
IPTablesOutputDropped tcpip.MultiCounterStat
- // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
- // of IPStats.
+ // IPTablesPostroutingDropped is the number of IP packets dropped in
+ // the Postrouting chain.
+ IPTablesPostroutingDropped tcpip.MultiCounterStat
+
+ // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option
+ // stats out of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
OptionTimestampReceived tcpip.MultiCounterStat
- // OptionRecordRouteReceived is the number of Record Route options seen.
+ // OptionRecordRouteReceived is the number of Record Route options
+ // seen.
OptionRecordRouteReceived tcpip.MultiCounterStat
- // OptionRouterAlertReceived is the number of Router Alert options seen.
+ // OptionRouterAlertReceived is the number of Router Alert options
+ // seen.
OptionRouterAlertReceived tcpip.MultiCounterStat
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived tcpip.MultiCounterStat
+
+ // Forwarding collects stats related to IP forwarding.
+ Forwarding MultiCounterIPForwardingStats
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
+ m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
+ m.Errors.Init(a.Errors, b.Errors)
+ m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
+ m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
+ m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
}
// Init sets internal counters to track a and b counters.
@@ -98,10 +142,12 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+ m.Forwarding.Init(&a.Forwarding, &b.Forwarding)
}
// LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index f5fa77b65..e2cf24b67 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -64,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if ep.allowPackets == 0 {
return ep.err
}
@@ -74,11 +74,11 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
var n int
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ if err := ep.WritePacket(r, protocol, pkt); err != nil {
return n, err
}
n++
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a4edc69c7..74aad126c 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "fmt"
"strings"
"testing"
@@ -29,23 +30,25 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
-const (
- localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
- remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
- ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
- ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
- ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
- localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
- ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
- ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- nicID = 1
+const nicID = 1
+
+var (
+ localIPv4Addr = testutil.MustParse4("10.0.0.1")
+ remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
+ ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
+ ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
+ ipv4Gateway = testutil.MustParse4("10.0.0.3")
+ localIPv6Addr = testutil.MustParse6("a00::1")
+ remoteIPv6Addr = testutil.MustParse6("a00::2")
+ ipv6SubnetAddr = testutil.MustParse6("a00::")
+ ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
+ ipv6Gateway = testutil.MustParse6("a00::3")
)
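The address constants become vars built from readable strings. A minimal sketch of the parser these helpers are assumed to wrap, panicking so a bad literal fails the test binary immediately (assumes "net" and "fmt" are imported):

// MustParse4 is assumed to behave roughly like this.
func MustParse4(addr string) tcpip.Address {
	ip := net.ParseIP(addr).To4()
	if ip == nil {
		panic(fmt.Sprintf("%q is not a valid IPv4 address", addr))
	}
	return tcpip.Address(ip)
}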
var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
@@ -180,7 +183,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -202,7 +205,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
@@ -323,7 +326,7 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -588,7 +591,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -1015,7 +1018,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -1938,3 +1941,80 @@ func TestICMPInclusionSize(t *testing.T) {
})
}
}
+
+func TestJoinLeaveAllRoutersGroup(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ protoFactory stack.NetworkProtocolFactory
+ allRoutersAddr tcpip.Address
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ protoFactory: ipv4.NewProtocol,
+ allRoutersAddr: header.IPv4AllRoutersGroup,
+ },
+ {
+ name: "IPv6 Interface Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Link Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Site Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, nicDisabled := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+ opts := stack.NICOptions{Disabled: nicDisabled}
+ if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
+ }
+
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if !got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, false); err != nil {
+ t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 5e7f10f4b..7ee0495d9 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -45,6 +45,7 @@ go_test(
"//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1525f15db..c8ed1ce79 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
return
}
- // Skip the ip header, then deliver the error.
- pkt.Data().TrimFront(hlen)
+ // Keep the needed information before trimming the header.
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
+ dstAddr := hdr.DestinationAddress()
+ // Skip the IP header, then deliver the error.
+ pkt.Data().DeleteFront(hlen)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
@@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
- pkt.Data().TrimFront(header.ICMPv4MinimumSize)
- switch h.Code() {
+ mtu := h.MTU()
+ code := h.Code()
+ pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
- networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize)
if err != nil {
networkMTU = 0
}
@@ -442,6 +446,23 @@ func (r *icmpReasonParamProblem) isForwarding() bool {
return r.forwarding
}
+// icmpReasonNetworkUnreachable is an error in which the network specified in
+// the internet destination field of the datagram is unreachable.
+type icmpReasonNetworkUnreachable struct{}
+
+func (*icmpReasonNetworkUnreachable) isICMPReason() {}
+func (*icmpReasonNetworkUnreachable) isForwarding() bool {
+ // If we hit a Net Unreachable error, then we know we are operating as
+ // a router. As per RFC 792 page 5, Destination Unreachable Message,
+ //
+ // If, according to the information in the gateway's routing tables,
+ // the network specified in the internet destination field of a
+ // datagram is unreachable, e.g., the distance to the network is
+ // infinity, the gateway may send a destination unreachable message to
+ // the internet source host of the datagram.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -610,6 +631,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetworkUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4NetUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
@@ -629,7 +654,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
- nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
TTL: route.DefaultTTL(),
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index f3fc1c87e..3ce499298 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2236 section 6 page 10,
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ return groupAddress != header.IPv4AllSystems
+}
+
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
//
@@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv4AllSystems,
})
igmp.igmpV1Present = igmpV1PresentDefault
igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
@@ -331,7 +341,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
}
sentStats := igmp.ep.stats.igmp.packetsSent
- if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), ProtocolNumber, pkt); err != nil {
sentStats.dropped.Increment()
return false, err
}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index e5e1b89cc..4bd6f462e 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -26,18 +26,22 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- stackAddr = tcpip.Address("\x0a\x00\x00\x01")
- remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
- multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
defaultTTL = 1
defaultPrefixLength = 24
)
+var (
+ stackAddr = testutil.MustParse4("10.0.0.1")
+ remoteAddr = testutil.MustParse4("10.0.0.2")
+ multicastAddr = testutil.MustParse4("224.0.0.3")
+)
+
// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
// sent to the provided address with the passed fields set. Raises a t.Error if
// any field does not match.
@@ -292,7 +296,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPLeaveGroup,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
@@ -302,7 +306,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPMembershipQuery,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: true,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
@@ -312,7 +316,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv1MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
@@ -322,7 +326,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv2MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
@@ -332,7 +336,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv2MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{
- {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24},
+ {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
{Address: stackAddr, PrefixLen: 24},
},
srcAddr: remoteAddr,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1a5661ca4..b11e56c6a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -150,6 +151,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
+// transitionForwarding transitions the endpoint to the given forwarding
+// status.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if forwarding {
+ // There does not seem to be an RFC requirement for a node to join the
+ // all-routers multicast address, but
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml
+ // specifies the address as a group for all routers on a subnet, so we
+ // join the group here.
+ if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+
+ return
+ }
+
+ switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+}
+
// Enable implements stack.NetworkEndpoint.
func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
@@ -226,7 +259,7 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
@@ -318,7 +351,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -338,7 +371,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := e.addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* options */); err != nil {
return err
}
@@ -346,7 +379,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -369,10 +402,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
}
- return e.writePacket(r, gso, pkt, false /* headerIncluded */)
+ return e.writePacket(r, pkt, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop()&stack.PacketLoop != 0 {
// If the packet was generated by the stack (not a raw/packet endpoint
// where a packet may be written with the header included), then we can
@@ -383,6 +416,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
@@ -391,20 +433,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return err
}
- if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
+ if packetMustBeFragmented(pkt, networkMTU) {
+ sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
- return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ return e.nic.WritePacket(r, ProtocolNumber, fragPkt)
})
stats.PacketsSent.IncrementBy(uint64(sent))
stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil {
stats.OutgoingPacketErrors.Increment()
return err
}
@@ -413,7 +455,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
@@ -434,11 +476,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return 0, err
}
- if packetMustBeFragmented(pkt, networkMTU, gso) {
+ if packetMustBeFragmented(pkt, networkMTU) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
+ if _, _, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
@@ -454,9 +496,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -478,14 +520,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
- written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ written, err := e.nic.WritePackets(r, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(written))
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
@@ -545,12 +596,31 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return &tcpip.ErrMalformedHeader{}
}
- return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
+ return e.writePacket(r, pkt, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv4(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ // As per RFC 3927 section 7,
+ //
+ // A router MUST NOT forward a packet with an IPv4 Link-Local source or
+ // destination address, irrespective of the router's default route
+ // configuration or routes obtained from dynamic routing protocols.
+ //
+ // A router which receives a packet with an IPv4 Link-Local source or
+ // destination address MUST NOT forward the packet. This prevents
+ // forwarding of packets back onto the network segment from which they
+ // originated, or to any other segment.
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
+ }
+
ttl := h.TTL()
if ttl == 0 {
// As per RFC 792 page 6, Time Exceeded Message,
@@ -558,7 +628,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// If the gateway processing a datagram finds the time to live field
// is zero it must discard the datagram. The gateway may also notify
// the source host via the time exceeded message.
- return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
if opts := h.Options(); len(opts) != 0 {
@@ -569,10 +644,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
pointer: optProblem.Pointer,
forwarding: true,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- e.stats.ip.MalformedPacketsReceived.Increment()
}
- return nil // option problems are not reported locally.
+ return &ip.ErrIPOptProblem{}
}
copied := copy(opts, newOpts)
if copied != len(newOpts) {
@@ -589,8 +662,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -598,8 +669,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
@@ -616,10 +695,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ })); err != nil {
+ return &ip.ErrOther{Err: err}
+ }
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -668,7 +750,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -734,14 +816,31 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
- _ = e.forwardPacket(pkt)
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrIPOptProblem:
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ stats.ip.Forwarding.Errors.Increment()
return
}
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -1114,28 +1213,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
return nil, false
}
- // There has been some confusion regarding verifying checksums. We need
- // just look for negative 0 (0xffff) as the checksum, as it's not possible to
- // get positive 0 (0) for the checksum. Some bad implementations could get it
- // when doing entry replacement in the early days of the Internet,
- // however the lore that one needs to check for both persists.
- //
- // RFC 1624 section 1 describes the source of this confusion as:
- // [the partial recalculation method described in RFC 1071] computes a
- // result for certain cases that differs from the one obtained from
- // scratch (one's complement of one's complement sum of the original
- // fields).
- //
- // However RFC 1624 section 5 clarifies that if using the verification method
- // "recommended by RFC 1071, it does not matter if an intermediate system
- // generated a -0 instead of +0".
- //
- // RFC1071 page 1 specifies the verification method as:
- // (3) To check a checksum, the 1's complement sum is computed over the
- // same set of octets, including the checksum field. If the result
- // is all 1 bits (-0 in 1's complement arithmetic), the check
- // succeeds.
- if h.CalculateChecksum() != 0xffff {
+ if !h.IsChecksumValid() {
return nil, false
}
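Per the RFC 1071 verification method the removed comment cites, the new helper is assumed to reduce to a negative-zero check:

// Sketch of the assumed helper: the one's complement sum over the header,
// including the checksum field itself, must be all ones (-0, i.e. 0xffff)
// for the header to be valid.
func (b IPv4) IsChecksumValid() bool {
	return b.CalculateChecksum() == 0xffff
}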
@@ -1168,12 +1246,27 @@ func (p *protocol) Forwarding() bool {
return uint8(atomic.LoadUint32(&p.forwarding)) == 1
}
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
+ }
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
+}
+
// SetForwarding implements stack.ForwardingNetworkProtocol.
func (p *protocol) SetForwarding(v bool) {
- if v {
- atomic.StoreUint32(&p.forwarding, 1)
- } else {
- atomic.StoreUint32(&p.forwarding, 0)
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for _, ep := range p.mu.eps {
+ ep.transitionForwarding(v)
}
}
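Because setForwarding is a compare-and-swap, SetForwarding is idempotent and endpoints only observe genuine transitions. An illustrative call sequence, assuming a fresh *protocol p with forwarding off:

p.SetForwarding(true)  // CAS 0->1 succeeds; transitionForwarding joins 224.0.0.2 on every endpoint
p.SetForwarding(true)  // CAS fails; returns early, so no duplicate joins
p.SetForwarding(false) // CAS 1->0 succeeds; endpoints leave the group exactly once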
@@ -1200,9 +1293,9 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error
return networkMTU - uint32(networkHeaderSize), nil
}
-func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool {
payload := pkt.TransportHeader().View().Size() + pkt.Data().Size()
- return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
+ return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU
}
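With GSO carried on the packet itself, the fragmentation decision needs no extra argument. A worked example under assumed values:

// With a 1500-byte link MTU and the 20-byte minimum IPv4 header,
// calculateNetworkMTU leaves 1480 bytes for the IP payload.
networkMTU, _ := calculateNetworkMTU(1500, header.IPv4MinimumSize) // 1480
// A non-GSO packet (pkt.GSOOptions.Type == stack.GSONone) whose transport
// header plus data exceeds 1480 bytes takes the handleFragments path.
mustFragment := packetMustBeFragmented(pkt, networkMTU) // pkt assumed in scope
_ = mustFragment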
// addressToUint32 translates an IPv4 address into its little endian uint32
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index eba91c68c..7a7cad04a 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -39,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -130,48 +131,69 @@ func TestForwarding(t *testing.T) {
}
remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+ unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4())
+ multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4())
+ linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4())
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
- options header.IPv4Options
- forwardedOptions header.IPv4Options
- icmpType header.ICMPv4Type
- icmpCode header.ICMPv4Code
+ name string
+ TTL uint8
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ expectErrorICMP bool
+ expectPacketForwarded bool
+ options header.IPv4Options
+ forwardedOptions header.IPv4Options
+ icmpType header.ICMPv4Type
+ icmpCode header.ICMPv4Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
}{
{
name: "TTL of zero",
TTL: 0,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
expectErrorICMP: true,
icmpType: header.ICMPv4TimeExceeded,
icmpCode: header.ICMPv4TTLExceeded,
},
{
- name: "TTL of one",
- TTL: 1,
- expectErrorICMP: false,
+ name: "TTL of one",
+ TTL: 1,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "TTL of two",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
},
{
- name: "four EOL options",
- TTL: 2,
- expectErrorICMP: false,
- options: header.IPv4Options{0, 0, 0, 0},
- forwardedOptions: header.IPv4Options{0, 0, 0, 0},
+ name: "four EOL options",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ options: header.IPv4Options{0, 0, 0, 0},
+ forwardedOptions: header.IPv4Options{0, 0, 0, 0},
},
{
- name: "TS type 1 full",
- TTL: 2,
+ name: "TS type 1 full",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
options: header.IPv4Options{
68, 12, 13, 0xF1,
192, 168, 1, 12,
@@ -182,8 +204,10 @@ func TestForwarding(t *testing.T) {
icmpCode: header.ICMPv4UnusedCode,
},
{
- name: "TS type 0",
- TTL: 2,
+ name: "TS type 0",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
options: header.IPv4Options{
68, 24, 21, 0x00,
1, 2, 3, 4,
@@ -200,10 +224,13 @@ func TestForwarding(t *testing.T) {
13, 14, 15, 16,
0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
},
+ expectPacketForwarded: true,
},
{
- name: "end of options list",
- TTL: 2,
+ name: "end of options list",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
options: header.IPv4Options{
68, 12, 13, 0x11,
192, 168, 1, 12,
@@ -219,6 +246,37 @@ func TestForwarding(t *testing.T) {
0, 0, 0, // 7 bytes unknown option removed.
0, 0, 0, 0,
},
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Network unreachable",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: unreachableIPv4Addr,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4NetUnreachable,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Multicast destination",
+ TTL: 2,
+ destAddr: multicastIPv4Addr,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: linkLocalIPv4Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv4Addr,
+ destAddr: remoteIPv4Addr2,
+ expectLinkLocalSourceError: true,
},
}
for _, test := range tests {
@@ -286,8 +344,8 @@ func TestForwarding(t *testing.T) {
TotalLength: totalLen,
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: test.TTL,
- SrcAddr: remoteIPv4Addr1,
- DstAddr: remoteIPv4Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
})
if len(test.options) != 0 {
ip.SetHeaderLength(uint8(ipHeaderLength))
@@ -304,15 +362,15 @@ func TestForwarding(t *testing.T) {
})
e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e1.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
checker.SrcAddr(ipv4Addr1.Address),
- checker.DstAddr(remoteIPv4Addr1),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(ipv4.DefaultTTL),
checker.ICMPv4(
checker.ICMPv4Checksum(),
@@ -325,15 +383,19 @@ func TestForwarding(t *testing.T) {
if n := e2.Drain(); n != 0 {
t.Fatalf("got e2.Drain() = %d, want = 0", n)
}
- } else {
- reply, ok := e2.Read()
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
+
+ reply, ok = e2.Read()
+ if test.expectPacketForwarded {
if !ok {
t.Fatal("expected ICMP Echo packet through outgoing NIC")
}
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv4Addr1),
- checker.DstAddr(remoteIPv4Addr2),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
checker.IPv4Options(test.forwardedOptions),
checker.ICMPv4(
@@ -347,6 +409,39 @@ func TestForwarding(t *testing.T) {
if n := e1.Drain(); n != 0 {
t.Fatalf("got e1.Drain() = %d, want = 0", n)
}
+ } else if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
+ }
+ return 0
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want {
+ t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL == 0); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
}
})
}
@@ -1241,7 +1336,6 @@ type fragmentInfo struct {
var fragmentationTests = []struct {
description string
mtu uint32
- gso *stack.GSO
transportHeaderLength int
payloadSize int
wantFragments []fragmentInfo
@@ -1249,7 +1343,6 @@ var fragmentationTests = []struct {
{
description: "No fragmentation",
mtu: 1280,
- gso: nil,
transportHeaderLength: 0,
payloadSize: 1000,
wantFragments: []fragmentInfo{
@@ -1259,7 +1352,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented",
mtu: 1280,
- gso: nil,
transportHeaderLength: 0,
payloadSize: 2000,
wantFragments: []fragmentInfo{
@@ -1270,7 +1362,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented with the minimum mtu",
mtu: header.IPv4MinimumMTU,
- gso: nil,
transportHeaderLength: 0,
payloadSize: 100,
wantFragments: []fragmentInfo{
@@ -1282,7 +1373,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented with mtu not a multiple of 8",
mtu: header.IPv4MinimumMTU + 1,
- gso: nil,
transportHeaderLength: 0,
payloadSize: 100,
wantFragments: []fragmentInfo{
@@ -1294,7 +1384,6 @@ var fragmentationTests = []struct {
{
description: "No fragmentation with big header",
mtu: 2000,
- gso: nil,
transportHeaderLength: 100,
payloadSize: 1000,
wantFragments: []fragmentInfo{
@@ -1302,20 +1391,8 @@ var fragmentationTests = []struct {
},
},
{
- description: "Fragmented with gso none",
- mtu: 1280,
- gso: &stack.GSO{Type: stack.GSONone},
- transportHeaderLength: 0,
- payloadSize: 1400,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1256, more: true},
- {offset: 1256, payloadSize: 144, more: false},
- },
- },
- {
description: "Fragmented with big header",
mtu: 1280,
- gso: nil,
transportHeaderLength: 100,
payloadSize: 1200,
wantFragments: []fragmentInfo{
@@ -1326,7 +1403,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented with MTU smaller than header",
mtu: 300,
- gso: nil,
transportHeaderLength: 1000,
payloadSize: 500,
wantFragments: []fragmentInfo{
@@ -1349,13 +1425,13 @@ func TestFragmentationWritePacket(t *testing.T) {
r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
source := pkt.Clone()
- err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Fatalf("r.WritePacket(_, _, _) = %s", err)
+ t.Fatalf("r.WritePacket(...): %s", err)
}
if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
@@ -1421,7 +1497,7 @@ func TestFragmentationWritePackets(t *testing.T) {
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
@@ -1528,7 +1604,7 @@ func TestFragmentationErrors(t *testing.T) {
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
- err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
@@ -2612,34 +2688,36 @@ func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2648,16 +2726,32 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2670,10 +2764,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2687,7 +2804,7 @@ func TestWriteStats(t *testing.T) {
writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
return nWritten, err
}
nWritten++
@@ -2697,7 +2814,7 @@ func TestWriteStats(t *testing.T) {
}, {
name: "WritePackets",
writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
},
},
}
@@ -2724,13 +2841,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
+ }
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
@@ -2995,12 +3115,14 @@ func TestCloseLocking(t *testing.T) {
nicID1 = 1
nicID2 = 2
- src = tcpip.Address("\x10\x00\x00\x01")
- dst = tcpip.Address("\x10\x00\x00\x02")
-
iterations = 1000
)
+ var (
+ src = tcptestutil.MustParse4("16.0.0.1")
+ dst = tcptestutil.MustParse4("16.0.0.2")
+ )
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
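
The switch above from hand-encoded byte strings to testutil.MustParse4 trades silent typos for a panic at test setup. A minimal, self-contained sketch of what such a helper can look like — the real pkg/tcpip/testutil helper returns a tcpip.Address; the [4]byte return here is an illustrative stand-in:

    package testutil

    import (
        "fmt"
        "net"
    )

    // MustParse4 parses a dotted-decimal IPv4 string, panicking on failure
    // so tests can call it directly in variable initializers.
    func MustParse4(addr string) [4]byte {
        ip := net.ParseIP(addr).To4()
        if ip == nil {
            panic(fmt.Sprintf("MustParse4(%q): not a valid IPv4 address", addr))
        }
        var out [4]byte
        copy(out[:], ip)
        return out
    }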
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index bb9a02ed0..db998e83e 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -66,5 +66,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index a142b76c1..ebb0b73df 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
return
}
+ // Save the needed information before trimming the header.
+ p := hdr.TransportProtocol()
+ dstAddr := hdr.DestinationAddress()
+
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().TrimFront(header.IPv6MinimumSize)
- p := hdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6MinimumSize)
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// because they don't have the transport headers.
return
}
+ p = fragHdr.TransportProtocol()
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().TrimFront(header.IPv6FragmentHeaderSize)
- p = fragHdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
}
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
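
The reordering in the hunk above — reading hdr.TransportProtocol() and hdr.DestinationAddress() before pkt.Data().DeleteFront(...) — matters because DeleteFront, unlike the old TrimFront, may release the underlying storage and invalidate views previously taken from it. A minimal sketch of the save-then-consume pattern, with illustrative types:

    package main

    import "fmt"

    type buf struct{ data []byte }

    // DeleteFront consumes n bytes; callers must not reuse views taken
    // before the call (imagine the prefix being freed here).
    func (b *buf) DeleteFront(n int) { b.data = b.data[n:] }

    func handle(b *buf) byte {
        // Save everything needed from the header first...
        nextHdr := b.data[6] // IPv6 Next Header lives at offset 6
        // ...then consume the fixed 40-byte IPv6 header.
        b.DeleteFront(40)
        return nextHdr
    }

    func main() {
        pkt := &buf{data: make([]byte, 60)}
        pkt.data[6] = 58 // ICMPv6
        fmt.Println(handle(pkt)) // 58
    }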
@@ -273,7 +276,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if iph.HopLimit() != header.MLDHopLimit {
return false
}
- if !header.IsV6LinkLocalAddress(iph.SourceAddress()) {
+ if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) {
return false
}
return true
@@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize)
networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
+ pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
@@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch header.ICMPv6(hdr).Code() {
+ code := header.ICMPv6(hdr).Code()
+ pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
@@ -564,7 +568,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
+ if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
sent.dropped.Increment()
return
}
@@ -704,7 +708,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: r.DefaultTTL(),
TOS: stack.DefaultTOS,
@@ -804,7 +808,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
routerAddr := srcAddr
// Is the IP Source Address a link-local address?
- if !header.IsV6LinkLocalAddress(routerAddr) {
+ if !header.IsV6LinkLocalUnicastAddress(routerAddr) {
// ...No, silently drop the packet.
received.invalid.Increment()
return
@@ -951,6 +955,7 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
// icmpReason is a marker interface for IPv6 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ isForwarding() bool
}
// icmpReasonParameterProblem is an error during processing of extension headers
@@ -982,6 +987,9 @@ type icmpReasonParameterProblem struct {
}
func (*icmpReasonParameterProblem) isICMPReason() {}
+func (*icmpReasonParameterProblem) isForwarding() bool {
+ return false
+}
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
@@ -989,12 +997,44 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+func (*icmpReasonPortUnreachable) isForwarding() bool {
+ return false
+}
+
+// icmpReasonNetUnreachable is an error where no route can be found to the
+// network of the final destination.
+type icmpReasonNetUnreachable struct{}
+
+func (*icmpReasonNetUnreachable) isICMPReason() {}
+
+func (*icmpReasonNetUnreachable) isForwarding() bool {
+ // If we hit a Network Unreachable error, then we also know we are
+ // operating as a router. As per RFC 4443 section 3.1:
+ //
+ // If the reason for the failure to deliver is lack of a matching
+ // entry in the forwarding node's routing table, the Code field is
+ // set to 0 (Network Unreachable).
+ return true
+}
+
// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
// transit to its final destination, as per RFC 4443 section 3.3.
type icmpReasonHopLimitExceeded struct{}
func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+func (*icmpReasonHopLimitExceeded) isForwarding() bool {
+ // If we hit a Hop Limit Exceeded error, then we know we are operating
+ // as a router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard
+ // the packet and originate an ICMPv6 Time Exceeded message with Code
+ // 0 to the source of the packet. This indicates either a routing
+ // loop or too small an initial Hop Limit value.
+ return true
+}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -1002,6 +1042,10 @@ type icmpReasonReassemblyTimeout struct{}
func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+func (*icmpReasonReassemblyTimeout) isForwarding() bool {
+ return false
+}
+
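
Pushing the router-vs-host decision into the icmpReason types, as the hunks above do, replaces a growing type assertion in returnError with one method call per reason. A standalone sketch of the marker-interface pattern (names illustrative, not the gvisor types):

    package main

    import "fmt"

    type icmpReason interface {
        isForwarding() bool
    }

    type hopLimitExceeded struct{}
    type portUnreachable struct{}

    func (*hopLimitExceeded) isForwarding() bool { return true }
    func (*portUnreachable) isForwarding() bool  { return false }

    // responseSource picks the source address for an ICMP error. A
    // forwarding router must not claim the packet's destination address
    // as its own, so it leaves the choice to the stack ("").
    func responseSource(r icmpReason, origDst string) string {
        if r.isForwarding() {
            return ""
        }
        return origDst
    }

    func main() {
        fmt.Printf("%q\n", responseSource(&hopLimitExceeded{}, "2001:db8::1")) // ""
        fmt.Printf("%q\n", responseSource(&portUnreachable{}, "2001:db8::1"))  // "2001:db8::1"
    }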
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
@@ -1040,15 +1084,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return nil
}
- // If we hit a Hop Limit Exceeded error, then we know we are operating as a
- // router. As per RFC 4443 section 3.3:
- //
- // If a router receives a packet with a Hop Limit of zero, or if a
- // router decrements a packet's Hop Limit to zero, it MUST discard the
- // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
- // the source of the packet. This indicates either a routing loop or
- // too small an initial Hop Limit value.
- //
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
@@ -1058,7 +1093,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
// packet as "multicast addresses must not be used as source addresses in IPv6
// packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
+ if reason.isForwarding() || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -1147,6 +1182,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonHopLimitExceeded:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
@@ -1167,7 +1206,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
PayloadLen: dataRange.Size(),
}))
if err := route.WritePacket(
- nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: route.DefaultTTL(),
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 6a7705ed1..e457be3cf 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -81,7 +81,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
return nil
}
@@ -130,19 +130,19 @@ func (*testInterface) Spoofing() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
+func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
+func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var r stack.RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
- return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
+ return t.LinkEndpoint.WritePacket(r, protocol, pkt)
}
func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c6d9d8f0d..659057fa7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
// Snooping switches MUST manage multicast forwarding state based on MLD
// Report and Done messages sent with the unspecified address as the
// IPv6 source address.
- if header.IsV6LinkLocalAddress(addr) {
+ if header.IsV6LinkLocalUnicastAddress(addr) {
e.mu.mld.sendQueuedReports()
}
}
@@ -410,24 +410,52 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
//
// Must only be called when the forwarding status changes.
func (e *endpoint) transitionForwarding(forwarding bool) {
+ allRoutersGroups := [...]tcpip.Address{
+ header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ header.IPv6AllRoutersLinkLocalMulticastAddress,
+ header.IPv6AllRoutersSiteLocalMulticastAddress,
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
- if !e.Enabled() {
- return
- }
-
if forwarding {
- // When transitioning into an IPv6 router, host-only state (NDP discovered
- // routers, discovered on-link prefixes, and auto-generated addresses) is
- // cleaned up/invalidated and NDP router solicitations are stopped.
- e.mu.ndp.stopSolicitingRouters()
- e.mu.ndp.cleanupState(true /* hostOnly */)
+ // As per RFC 4291 section 2.8:
+ //
+ // A router is required to recognize all addresses that a host is
+ // required to recognize, plus the following addresses as identifying
+ // itself:
+ //
+ // o The All-Routers multicast addresses defined in Section 2.7.1.
+ //
+ // As per RFC 4291 section 2.7.1,
+ //
+ // All Routers Addresses: FF01:0:0:0:0:0:0:2
+ // FF02:0:0:0:0:0:0:2
+ // FF05:0:0:0:0:0:0:2
+ //
+ // The above multicast addresses identify the group of all IPv6 routers,
+ // within scope 1 (interface-local), 2 (link-local), or 5 (site-local).
+ for _, g := range allRoutersGroups {
+ if err := e.joinGroupLocked(g); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err))
+ }
+ }
} else {
- // When transitioning into an IPv6 host, NDP router solicitations are
- // started.
- e.mu.ndp.startSolicitingRouters()
+ for _, g := range allRoutersGroups {
+ switch err := e.leaveGroupLocked(g).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ }
+ }
}
+
+ e.mu.ndp.forwardingChanged(forwarding)
}
// Enable implements stack.NetworkEndpoint.
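
The new transitionForwarding above is deliberately symmetric: join the three all-routers groups when forwarding turns on, tolerate an already-left group when it turns off, and panic on anything else, since both calls can only fail in ways that indicate a programming error. A compact sketch of that shape, with stand-in group operations:

    package main

    import (
        "errors"
        "fmt"
    )

    var errBadLocalAddress = errors.New("bad local address")

    // Stand-ins for joinGroupLocked/leaveGroupLocked.
    func joinGroup(g string) error  { return nil }
    func leaveGroup(g string) error { return nil }

    var allRoutersGroups = [...]string{"ff01::2", "ff02::2", "ff05::2"}

    func transitionForwarding(forwarding bool) {
        for _, g := range allRoutersGroups {
            if forwarding {
                if err := joinGroup(g); err != nil {
                    // Only fails for invalid multicast addresses.
                    panic(fmt.Sprintf("joinGroup(%s): %s", g, err))
                }
                continue
            }
            if err := leaveGroup(g); err != nil && !errors.Is(err, errBadLocalAddress) {
                // Already-left is fine; anything else is a bug.
                panic(fmt.Sprintf("leaveGroup(%s): %s", g, err))
            }
        }
    }

    func main() { transitionForwarding(true); transitionForwarding(false) }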
@@ -509,17 +537,7 @@ func (e *endpoint) Enable() tcpip.Error {
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
}
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyway.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !e.protocol.Forwarding() {
- e.mu.ndp.startSolicitingRouters()
- }
-
+ e.mu.ndp.startSolicitingRouters()
return nil
}
@@ -570,10 +588,10 @@ func (e *endpoint) disableLocked() {
return true
})
- e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.mu.ndp.cleanupState()
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
@@ -632,9 +650,9 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params
return nil
}
-func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool {
payload := pkt.TransportHeader().View().Size() + pkt.Data().Size()
- return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
+ return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU
}
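
With the GSO options carried on the packet buffer itself (pkt.GSOOptions) rather than threaded through every call, the fragmentation predicate above becomes a pure function of the packet and the MTU. A self-contained sketch with illustrative types:

    package main

    import "fmt"

    type gsoType int

    const gsoNone gsoType = 0

    type packetBuffer struct {
        gso              gsoType // stand-in for pkt.GSOOptions.Type
        transportHdrSize int
        dataSize         int
    }

    // mustFragment reports whether pkt needs software fragmentation:
    // no hardware segmentation and a payload larger than the MTU.
    func mustFragment(pkt packetBuffer, networkMTU uint32) bool {
        payload := pkt.transportHdrSize + pkt.dataSize
        return pkt.gso == gsoNone && uint32(payload) > networkMTU
    }

    func main() {
        fmt.Println(mustFragment(packetBuffer{dataSize: 2000}, 1280)) // true
        fmt.Println(mustFragment(packetBuffer{dataSize: 1000}, 1280)) // false
    }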
// handleFragments fragments pkt and calls the handler function on each
@@ -642,7 +660,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta
// fragments left to be processed. The IP header must already be present in the
// original packet. The transport header protocol number is required to avoid
// parsing the IPv6 extension headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
networkHeader := header.IPv6(pkt.NetworkHeader().View())
// TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
@@ -681,7 +699,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* extensionHeaders */); err != nil {
return err
}
@@ -689,7 +707,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -712,10 +730,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
}
- return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+ return e.writePacket(r, pkt, params.Protocol, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop()&stack.PacketLoop != 0 {
// If the packet was generated by the stack (not a raw/packet endpoint
// where a packet may be written with the header included), then we can
@@ -726,6 +744,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
@@ -733,20 +760,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return err
}
- if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
+ if packetMustBeFragmented(pkt, networkMTU) {
+ sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
- return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ return e.nic.WritePacket(r, ProtocolNumber, fragPkt)
})
stats.PacketsSent.IncrementBy(uint64(sent))
stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil {
stats.OutgoingPacketErrors.Increment()
return err
}
@@ -756,7 +783,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("not implemented")
}
@@ -776,11 +803,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
- if packetMustBeFragmented(pb, networkMTU, gso) {
+ if packetMustBeFragmented(pb, networkMTU) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
+ if _, _, err := e.handleFragments(r, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
@@ -797,9 +824,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -820,14 +847,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
locallyDelivered++
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
- written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ written, err := e.nic.WritePackets(r, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(written))
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
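
The WritePackets hunk above fixes the batch ordering — Output filtering, local delivery of NAT-ed packets, then Postrouting, then the NIC write — and folds both kinds of iptables drops into the returned count, since a dropped packet was handled rather than failed. A sketch of that accounting with stand-in chain checks:

    package main

    import "fmt"

    type packet struct{ id int }

    // Stand-in for IPTables().CheckPackets: returns the dropped set.
    func checkChain(chain string, pkts []packet) map[int]bool { return map[int]bool{} }

    func remove(pkts []packet, dropped map[int]bool) []packet {
        kept := pkts[:0]
        for _, p := range pkts {
            if !dropped[p.id] {
                kept = append(kept, p)
            }
        }
        return kept
    }

    func nicWrite(pkts []packet) (int, error) { return len(pkts), nil }

    func writeBatch(pkts []packet) (int, error) {
        outDropped := checkChain("Output", pkts)
        pkts = remove(pkts, outDropped)
        postDropped := checkChain("Postrouting", pkts)
        pkts = remove(pkts, postDropped)
        written, err := nicWrite(pkts)
        // Dropped packets aren't write errors; report them as handled.
        return written + len(outDropped) + len(postDropped), err
    }

    func main() {
        n, _ := writeBatch([]packet{{1}, {2}, {3}})
        fmt.Println(n) // 3
    }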
@@ -863,12 +899,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return &tcpip.ErrMalformedHeader{}
}
- return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+ return e.writePacket(r, pkt, proto, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv6(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ // As per RFC 4291 section 2.5.6,
+ //
+ // Routers must not forward any packets with Link-Local source or
+ // destination addresses to other links.
+ if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
+ }
+
hopLimit := h.HopLimit()
if hopLimit <= 1 {
// As per RFC 4443 section 3.3,
@@ -878,11 +927,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// packet and originate an ICMPv6 Time Exceeded message with Code 0 to
// the source of the packet. This indicates either a routing loop or
// too small an initial Hop Limit value.
- return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -890,8 +942,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
@@ -906,10 +966,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ })); err != nil {
+ return &ip.ErrOther{Err: err}
+ }
+ return nil
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
@@ -958,7 +1021,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1010,8 +1073,21 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
stats.InvalidDestinationAddressesReceived.Increment()
return
}
-
- _ = e.forwardPacket(pkt)
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ e.stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ e.stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ e.stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ e.stats.ip.Forwarding.Unrouteable.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ e.stats.ip.Forwarding.Errors.Increment()
return
}
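
forwardPacket now returns a typed ip.ForwardingError, and handleValidatedPacket above turns each variant into exactly one per-reason counter increment plus the aggregate Errors counter. A reduced sketch of that mapping (error and counter names are stand-ins for the ip package types):

    package main

    import "fmt"

    type forwardingError interface{ isForwardingError() }

    type errLinkLocalSource struct{}
    type errTTLExceeded struct{}
    type errNoRoute struct{}

    func (*errLinkLocalSource) isForwardingError() {}
    func (*errTTLExceeded) isForwardingError()     {}
    func (*errNoRoute) isForwardingError()         {}

    var counters = map[string]uint64{}

    func recordForwardingResult(err forwardingError) {
        switch err.(type) {
        case nil:
            return
        case *errLinkLocalSource:
            counters["LinkLocalSource"]++
        case *errTTLExceeded:
            counters["ExhaustedTTL"]++
        case *errNoRoute:
            counters["Unrouteable"]++
        default:
            panic(fmt.Sprintf("unexpected forwarding error: %#v", err))
        }
        counters["Errors"]++ // every failed forward also lands here
    }

    func main() {
        recordForwardingResult(nil)
        recordForwardingResult(&errTTLExceeded{})
        fmt.Println(counters) // map[Errors:1 ExhaustedTTL:1]
    }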
@@ -1028,7 +1104,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1571,7 +1647,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
var linkLocalAddr tcpip.Address
e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.IsAssigned(false /* allowExpired */) {
- if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) {
linkLocalAddr = addr
return false
}
@@ -1979,9 +2055,9 @@ func (p *protocol) Forwarding() bool {
// Returns true if the forwarding status was updated.
func (p *protocol) setForwarding(v bool) bool {
if v {
- return atomic.SwapUint32(&p.forwarding, 1) == 0
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
}
- return atomic.SwapUint32(&p.forwarding, 0) == 1
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
}
// SetForwarding implements stack.ForwardingNetworkProtocol.
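
The Swap-to-CompareAndSwap change above keeps the same "did the status actually change" return value but only writes when the flag really flips, so repeated SetForwarding(true) calls stop storing to the shared word. A self-contained sketch:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    var forwarding uint32

    // setForwarding returns true only when the call changed the flag.
    func setForwarding(v bool) bool {
        if v {
            return atomic.CompareAndSwapUint32(&forwarding, 0 /* old */, 1 /* new */)
        }
        return atomic.CompareAndSwapUint32(&forwarding, 1 /* old */, 0 /* new */)
    }

    func main() {
        fmt.Println(setForwarding(true))  // true: 0 -> 1
        fmt.Println(setForwarding(true))  // false: already 1
        fmt.Println(setForwarding(false)) // true: 1 -> 0
    }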
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index c206cebeb..4fbe39528 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2542,7 +2584,7 @@ func TestWriteStats(t *testing.T) {
writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
return nWritten, err
}
nWritten++
@@ -2552,7 +2594,7 @@ func TestWriteStats(t *testing.T) {
}, {
name: "WritePackets",
writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
},
},
}
@@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
+ }
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
@@ -2694,7 +2739,6 @@ type fragmentInfo struct {
var fragmentationTests = []struct {
description string
mtu uint32
- gso *stack.GSO
transHdrLen int
payloadSize int
wantFragments []fragmentInfo
@@ -2702,7 +2746,6 @@ var fragmentationTests = []struct {
{
description: "No fragmentation",
mtu: header.IPv6MinimumMTU,
- gso: nil,
transHdrLen: 0,
payloadSize: 1000,
wantFragments: []fragmentInfo{
@@ -2712,7 +2755,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented",
mtu: header.IPv6MinimumMTU,
- gso: nil,
transHdrLen: 0,
payloadSize: 2000,
wantFragments: []fragmentInfo{
@@ -2723,7 +2765,6 @@ var fragmentationTests = []struct {
{
description: "Fragmented with mtu not a multiple of 8",
mtu: header.IPv6MinimumMTU + 1,
- gso: nil,
transHdrLen: 0,
payloadSize: 2000,
wantFragments: []fragmentInfo{
@@ -2734,7 +2775,6 @@ var fragmentationTests = []struct {
{
description: "No fragmentation with big header",
mtu: 2000,
- gso: nil,
transHdrLen: 100,
payloadSize: 1000,
wantFragments: []fragmentInfo{
@@ -2742,20 +2782,8 @@ var fragmentationTests = []struct {
},
},
{
- description: "Fragmented with gso none",
- mtu: header.IPv6MinimumMTU,
- gso: &stack.GSO{Type: stack.GSONone},
- transHdrLen: 0,
- payloadSize: 1400,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 176, more: false},
- },
- },
- {
description: "Fragmented with big header",
mtu: header.IPv6MinimumMTU,
- gso: nil,
transHdrLen: 100,
payloadSize: 1200,
wantFragments: []fragmentInfo{
@@ -2778,7 +2806,7 @@ func TestFragmentationWritePacket(t *testing.T) {
source := pkt.Clone()
ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
- err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
@@ -2851,7 +2879,7 @@ func TestFragmentationWritePackets(t *testing.T) {
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
@@ -2955,7 +2983,7 @@ func TestFragmentationErrors(t *testing.T) {
pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
- err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
@@ -2991,36 +3019,94 @@ func TestForwarding(t *testing.T) {
}
remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+ unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
+ multicastIPv6Addr := tcpip.Address(net.ParseIP("ff00::").To16())
+ linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16())
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ expectPacketForwarded bool
+ countUnrouteablePackets uint64
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ icmpType header.ICMPv6Type
+ icmpCode header.ICMPv6Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
}{
{
name: "TTL of zero",
TTL: 0,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
},
{
name: "TTL of one",
TTL: 1,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
},
{
- name: "TTL of three",
- TTL: 3,
- expectErrorICMP: false,
+ name: "Network unreachable",
+ TTL: 2,
+ expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: unreachableIPv6Addr,
+ icmpType: header.ICMPv6DstUnreachable,
+ icmpCode: header.ICMPv6NetworkUnreachable,
+ expectPacketUnrouteableError: true,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Multicast destination",
+ TTL: 2,
+ countUnrouteablePackets: 1,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: linkLocalIPv6Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv6Addr,
+ destAddr: remoteIPv6Addr2,
+ expectLinkLocalSourceError: true,
},
}
@@ -3073,35 +3159,35 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmp,
- Src: remoteIPv6Addr1,
- Dst: remoteIPv6Addr2,
+ Src: test.sourceAddr,
+ Dst: test.destAddr,
}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
e1.InjectInbound(ProtocolNumber, requestPkt)
+ reply, ok := e1.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
- t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
checker.SrcAddr(ipv6Addr1.Address),
- checker.DstAddr(remoteIPv6Addr1),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(DefaultTTL),
checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6TimeExceeded),
- checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
+ checker.ICMPv6Type(test.icmpType),
+ checker.ICMPv6Code(test.icmpCode),
checker.ICMPv6Payload([]byte(hdr.View())),
),
)
@@ -3109,15 +3195,19 @@ func TestForwarding(t *testing.T) {
if n := e2.Drain(); n != 0 {
t.Fatalf("got e2.Drain() = %d, want = 0", n)
}
- } else {
- reply, ok := e2.Read()
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
+
+ reply, ok = e2.Read()
+ if test.expectPacketForwarded {
if !ok {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv6Addr1),
- checker.DstAddr(remoteIPv6Addr2),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6EchoRequest),
@@ -3129,6 +3219,35 @@ func TestForwarding(t *testing.T) {
if n := e1.Drain(); n != 0 {
t.Fatalf("got e1.Drain() = %d, want = 0", n)
}
+ } else if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
+ }
+ return 0
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want {
+ t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index dd153466d..bc1af193c 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error)
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
- _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2710 section 5 page 10,
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ //
+ // MLD messages are never sent for multicast addresses whose scope is 0
+ // (reserved) or 1 (node-local).
+ if groupAddress == header.IPv6AllNodesMulticastAddress {
+ return false
+ }
+
+ scope := header.V6MulticastScope(groupAddress)
+ return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope
+}
+
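
ShouldPerformProtocol above suppresses MLD for the all-nodes group and for scope values 0 and 1, keying off the 4-bit scope field of the multicast address. A standalone sketch, assuming the scope is the low nibble of the second byte as defined by RFC 4291 section 2.7:

    package main

    import "fmt"

    // v6MulticastScope returns the 4-bit scope field of an IPv6 multicast
    // address (low nibble of the second byte, per RFC 4291 section 2.7).
    func v6MulticastScope(addr [16]byte) uint8 {
        return addr[1] & 0xf
    }

    func shouldPerformMLD(addr [16]byte) bool {
        allNodes := [16]byte{0: 0xff, 1: 0x02, 15: 0x01} // ff02::1
        if addr == allNodes {
            return false
        }
        scope := v6MulticastScope(addr)
        return scope != 0x0 /* reserved */ && scope != 0x1 /* interface-local */
    }

    func main() {
        nodeLocal := [16]byte{0: 0xff, 1: 0x01, 15: 0x11} // ff01::11
        linkLocal := [16]byte{0: 0xff, 1: 0x02, 15: 0x11} // ff02::11
        fmt.Println(shouldPerformMLD(nodeLocal), shouldPerformMLD(linkLocal)) // false true
    }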
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
//
@@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv6AllNodesMulticastAddress,
})
}
@@ -259,7 +277,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
}, extensionHeaders); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), ProtocolNumber, pkt); err != nil {
sentStats.dropped.Increment()
return false, err
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 85a8f9944..71d1c3e28 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -27,15 +27,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var (
+ linkLocalAddr = testutil.MustParse6("fe80::1")
+ globalAddr = testutil.MustParse6("a80::1")
+ globalMulticastAddr = testutil.MustParse6("ff05:100::2")
+
linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
)
@@ -93,7 +92,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
if p, ok := e.Read(); !ok {
t.Fatal("expected a done message to be sent")
} else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
}
}
@@ -354,10 +353,8 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho
}
func TestMLDPacketValidation(t *testing.T) {
- const (
- nicID = 1
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- )
+ const nicID = 1
+ linkLocalAddr2 := testutil.MustParse6("fe80::2")
tests := []struct {
name string
@@ -464,3 +461,141 @@ func TestMLDPacketValidation(t *testing.T) {
})
}
}
+
+func TestMLDSkipProtocol(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ group tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Reserverd0",
+ group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Interface Local",
+ group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Link Local",
+ group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Realm Local",
+ group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Admin Local",
+ group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Site Local",
+ group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(6)",
+ group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(7)",
+ group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Organization Local",
+ group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(9)",
+ group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(A)",
+ group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(B)",
+ group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(C)",
+ group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(D)",
+ group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Global",
+ group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "ReservedF",
+ group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil {
+ t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err)
+ }
+ if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group)
+ }
+
+ if !test.expectReport {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p)
+ }
+
+ return
+ }
+
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group)
+ }
+ })
+ }
+}
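The TestMLDSkipProtocol table above walks every IPv6 multicast scope. Per RFC 7346, the scope is the low nibble of the address's second byte, and the test expects no MLD report for the reserved (0) and interface-local (1) scopes. A minimal sketch of that scope check (scopeOf and shouldSendMLD are hypothetical helpers, not part of this change):

package main

import "fmt"

// scopeOf returns the IPv6 multicast scope (RFC 7346 section 2): the low
// four bits of the second byte of the address.
func scopeOf(addr [16]byte) uint8 {
	return addr[1] & 0x0f
}

// shouldSendMLD mirrors what the test expects: no reports for the
// reserved (0) and interface-local (1) scopes, reports for everything else.
func shouldSendMLD(addr [16]byte) bool {
	s := scopeOf(addr)
	return s != 0x0 && s != 0x1
}

func main() {
	interfaceLocal := [16]byte{0xff, 0x01, 15: 0x11} // ff01::11
	linkLocal := [16]byte{0xff, 0x02, 15: 0x11}      // ff02::11
	fmt.Println(shouldSendMLD(interfaceLocal), shouldSendMLD(linkLocal)) // false true
}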
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 536493f87..b29fed347 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -48,7 +48,7 @@ const (
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- defaultHandleRAs = true
+ defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
@@ -301,10 +301,60 @@ type NDPDispatcher interface {
OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
+var _ fmt.Stringer = HandleRAsConfiguration(0)
+
+// HandleRAsConfiguration enumerates when RAs may be handled.
+type HandleRAsConfiguration int
+
+const (
+ // HandlingRAsDisabled indicates that Router Advertisements will not be
+ // handled.
+ HandlingRAsDisabled HandleRAsConfiguration = iota
+
+ // HandlingRAsEnabledWhenForwardingDisabled indicates that router
+ // advertisements will only be handled when forwarding is disabled.
+ HandlingRAsEnabledWhenForwardingDisabled
+
+ // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always
+ // be handled, even when forwarding is enabled.
+ HandlingRAsAlwaysEnabled
+)
+
+// String implements fmt.Stringer.
+func (c HandleRAsConfiguration) String() string {
+ switch c {
+ case HandlingRAsDisabled:
+ return "HandlingRAsDisabled"
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return "HandlingRAsEnabledWhenForwardingDisabled"
+ case HandlingRAsAlwaysEnabled:
+ return "HandlingRAsAlwaysEnabled"
+ default:
+ return fmt.Sprintf("HandleRAsConfiguration(%d)", c)
+ }
+}
+
+// enabled returns true iff Router Advertisements may be handled given the
+// specified forwarding status.
+func (c HandleRAsConfiguration) enabled(forwarding bool) bool {
+ switch c {
+ case HandlingRAsDisabled:
+ return false
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return !forwarding
+ case HandlingRAsAlwaysEnabled:
+ return true
+ default:
+ panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c))
+ }
+}
+
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
// The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
+ //
+ // Ignored unless configured to handle Router Advertisements.
MaxRtrSolicitations uint8
// The amount of time between transmitting Router Solicitation messages.
@@ -318,8 +368,9 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements are processed.
- HandleRAs bool
+ // HandleRAs is the configuration for when Router Advertisements should be
+ // handled.
+ HandleRAs HandleRAsConfiguration
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
@@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -737,7 +789,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
prefix := opt.Subnet()
// Is the prefix a link-local?
- if header.IsV6LinkLocalAddress(prefix.ID()) {
+ if header.IsV6LinkLocalUnicastAddress(prefix.ID()) {
// ...Yes, skip as per RFC 4861 section 6.3.4,
// and RFC 4862 section 5.5.3.b (for SLAAC).
continue
@@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t
delete(tempAddrs, tempAddr)
}
-// removeSLAACAddresses removes all SLAAC addresses.
-//
-// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
-//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
- linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- var linkLocalPrefixes int
- for prefix, state := range ndp.slaacPrefixes {
- // RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them if we are cleaning up
- // host-only state.
- if keepLinkLocal && prefix == linkLocalSubnet {
- linkLocalPrefixes++
- continue
- }
-
- ndp.invalidateSLAACPrefix(prefix, state)
- }
-
- if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
- panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
- }
-}
-
// cleanupState cleans up ndp's state.
//
-// If hostOnly is true, then only host-specific state is cleaned up.
-//
// This function invalidates all discovered on-link prefixes, discovered
// routers, and auto-generated addresses.
//
-// If hostOnly is true, then the link-local auto-generated address aren't
-// invalidated as routers are also expected to generate a link-local address.
-//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
- ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
+func (ndp *ndpState) cleanupState() {
+ for prefix, state := range ndp.slaacPrefixes {
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
+// If ndp is not configured to handle Router Advertisements, routers will not
+// be solicited as there is no point soliciting routers if we don't handle their
+// advertisements.
+//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicitTimer.timer != nil {
@@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ return
+ }
+
// Calculate the random delay before sending our first RS, as per RFC
// 4861 section 6.3.7.
var delay time.Duration
@@ -1703,7 +1735,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// the unspecified address if no address is assigned
// to the sending interface.
localAddr := header.IPv6Any
- if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil {
localAddr = addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
}
@@ -1730,7 +1762,7 @@ func (ndp *ndpState) startSolicitingRouters() {
icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpData,
Src: localAddr,
- Dst: header.IPv6AllRoutersMulticastAddress,
+ Dst: header.IPv6AllRoutersLinkLocalMulticastAddress,
}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1739,14 +1771,14 @@ func (ndp *ndpState) startSolicitingRouters() {
})
sent := ndp.ep.stats.icmp.packetsSent
- if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), ProtocolNumber, pkt); err != nil {
sent.dropped.Increment()
// Don't send any more messages if we had an error.
remaining = 0
@@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() {
}
}
+// forwardingChanged handles a change in forwarding configuration.
+//
+// If transitioning to a host, router solicitation will be started. Otherwise,
+// router solicitation will be stopped if NDP is not configured to handle RAs
+// as a router.
+//
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) forwardingChanged(forwarding bool) {
+ if forwarding {
+ if ndp.configs.HandleRAs.enabled(forwarding) {
+ return
+ }
+
+ ndp.stopSolicitingRouters()
+ return
+ }
+
+ // Solicit routers when transitioning to a host.
+ //
+ // If the endpoint is not currently enabled, routers will be solicited when
+ // the endpoint becomes enabled (if it is still a host).
+ if ndp.ep.Enabled() {
+ ndp.startSolicitingRouters()
+ }
+}
+
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
@@ -1839,7 +1897,7 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL
}
sent := e.stats.icmp.packetsSent
- err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
+ err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt)
if err != nil {
sent.dropped.Increment()
} else {
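The boolean HandleRAs knob above becomes a three-valued HandleRAsConfiguration, so whether an RA is processed now depends on both the configuration and the forwarding state. A standalone sketch of the resulting truth table, mirroring the enabled method from the diff (types renamed; illustrative only):

package main

import "fmt"

type handleRAs int

const (
	disabled handleRAs = iota
	enabledWhenForwardingDisabled
	alwaysEnabled
)

// enabled mirrors HandleRAsConfiguration.enabled from the diff.
func (c handleRAs) enabled(forwarding bool) bool {
	switch c {
	case disabled:
		return false
	case enabledWhenForwardingDisabled:
		return !forwarding
	default: // alwaysEnabled
		return true
	}
}

func main() {
	for _, c := range []handleRAs{disabled, enabledWhenForwardingDisabled, alwaysEnabled} {
		for _, fwd := range []bool{false, true} {
			fmt.Printf("config=%d forwarding=%t -> handle=%t\n", c, fwd, c.enabled(fwd))
		}
	}
}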
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
index c2758352f..2f18f60e8 100644
--- a/pkg/tcpip/network/ipv6/stats.go
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -29,6 +29,10 @@ type Stats struct {
// ICMP holds ICMPv6 statistics.
ICMP tcpip.ICMPv6Stats
+
+ // UnhandledRouterAdvertisements is the number of Router Advertisements that
+ // were observed but not handled.
+ UnhandledRouterAdvertisements *tcpip.StatCounter
}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index ecd5003a7..1b96b1fb8 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -30,22 +30,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
defaultIPv4PrefixLength = 24
- linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
- ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
- ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
- ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
- ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
- ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
@@ -59,6 +50,19 @@ const (
)
var (
+ stackIPv4Addr = testutil.MustParse4("10.0.0.1")
+ linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
+ linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
+
+ ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
+ ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
+ ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
+ ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
+ ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
+ ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
+)
+
+var (
// unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
// NIC will wait before sending an unsolicited report after joining a
// multicast group, in deciseconds.
@@ -194,7 +198,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
} else {
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
}
// Should not send any more packets.
@@ -606,7 +610,7 @@ func TestMGPLeaveGroup(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
checkInitialGroups: checkInitialIPv6Groups,
},
@@ -1014,7 +1018,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
},
getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
t.Helper()
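This test, like the others in the change, trades raw byte-string literals for testutil.MustParse4/MustParse6. The helpers themselves are not shown in the diff; a plausible sketch of the IPv6 variant, assuming it wraps net.ParseIP and panics on bad input (hypothetical implementation):

package testutil

import (
	"fmt"
	"net"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// MustParse6 parses addr as an IPv6 address and panics on failure, letting
// tests declare addresses as readable literals like "fe80::1". MustParse4
// would do the same via ip.To4() for dotted-quad literals.
func MustParse6(addr string) tcpip.Address {
	ip := net.ParseIP(addr)
	if ip == nil || ip.To4() != nil {
		panic(fmt.Sprintf("%q is not a valid IPv6 address", addr))
	}
	return tcpip.Address(ip.To16())
}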
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 210262703..b7f6d52ae 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -21,6 +21,7 @@ go_test(
library = ":ports",
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 678199371..b5b013b64 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -17,6 +17,7 @@
package ports
import (
+ "math"
"math/rand"
"sync/atomic"
@@ -24,7 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-const anyIPAddress tcpip.Address = ""
+const (
+ firstEphemeral = 16000
+ anyIPAddress tcpip.Address = ""
+)
// Reservation describes a port reservation.
type Reservation struct {
@@ -220,10 +224,8 @@ type PortManager struct {
func NewPortManager() *PortManager {
return &PortManager{
allocatedPorts: make(map[portDescriptor]addrToDevice),
- // Match Linux's default ephemeral range. See:
- // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
- firstEphemeral: 32768,
- numEphemeral: 28232,
+ firstEphemeral: firstEphemeral,
+ numEphemeral: math.MaxUint16 - firstEphemeral + 1,
}
}
@@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint16(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rand.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
// portHint atomically reads and returns the pm.hint value.
-func (pm *PortManager) portHint() uint16 {
- return uint16(atomic.LoadUint32(&pm.hint))
+func (pm *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&pm.hint)
}
// incPortHint atomically increments pm.hint by 1.
@@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
@@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
- for i := uint16(0); i < count; i++ {
- port = first + (offset+i)%count
+func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ for i := uint32(0); i < uint32(count); i++ {
+ port := uint16(uint32(first) + (offset+i)%uint32(count))
ok, err := testPort(port)
if err != nil {
return 0, err
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 0f43dc8f8..6c4fb8c68 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,19 +15,23 @@
package ports
import (
+ "math"
"math/rand"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
+)
- fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
- fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
+var (
+ fakeIPAddress = testutil.MustParse4("8.8.8.8")
+ fakeIPAddress1 = testutil.MustParse4("8.8.8.9")
)
type portReserveTestAction struct {
@@ -479,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) {
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- portOffset := uint16(rand.Int31n(int32(numEphemeralPorts)))
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@@ -490,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) {
})
}
}
+
+// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a
+// port allocation failure.
+func TestOverflow(t *testing.T) {
+ // Use a small range and start at offsets that will cause an overflow.
+ count := uint16(50)
+ for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ {
+ reservedPorts := make(map[uint16]struct{})
+ // Ensure we can reserve everything in the allowed range.
+ for i := uint16(0); i < count; i++ {
+ port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) {
+ if _, ok := reservedPorts[port]; !ok {
+ reservedPorts[port] = struct{}{}
+ return true, nil
+ }
+ return false, nil
+ })
+ if err != nil {
+ t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts))
+ }
+ if port < firstEphemeral || port >= firstEphemeral+count {
+ t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index dc37e61a4..a6c877158 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -58,6 +58,9 @@ type SocketOptionsHandler interface {
// changed. The handler is invoked with the new value for the socket send
// buffer size. It also returns the newly set value.
OnSetSendBufferSize(v int64) (newSz int64)
+
+ // OnSetReceiveBufferSize is invoked when SO_RCVBUF is set, with the requested and old values; it returns the value to be set.
+ OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -99,6 +102,11 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
+func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
+ return v
+}
+
// StackHandler holds methods to access the stack options. These must be
// implemented by the stack.
type StackHandler interface {
@@ -207,6 +215,14 @@ type SocketOptions struct {
// sendBufferSize determines the send buffer size for this socket.
sendBufferSize int64
+ // getReceiveBufferLimits provides the handler to get the min, default and
+ // max receive buffer sizes. It is initialized at creation time and will
+ // not change.
+ getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
+
+ // receiveBufferSize determines the receive buffer size for this socket.
+ receiveBufferSize int64
+
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -217,10 +233,11 @@ type SocketOptions struct {
// InitHandler initializes the handler. This must be called before using the
// socket options utility.
-func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) {
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) {
so.handler = handler
so.stackHandler = stack
so.getSendBufferLimits = getSendBufferLimits
+ so.getReceiveBufferLimits = getReceiveBufferLimits
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -632,3 +649,42 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
newSz := so.handler.OnSetSendBufferSize(v)
atomic.StoreInt64(&so.sendBufferSize, newSz)
}
+
+// GetReceiveBufferSize gets value for SO_RCVBUF option.
+func (so *SocketOptions) GetReceiveBufferSize() int64 {
+ return atomic.LoadInt64(&so.receiveBufferSize)
+}
+
+// SetReceiveBufferSize sets value for SO_RCVBUF option.
+func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
+ if !notify {
+ atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize)
+ return
+ }
+
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ v := receiveBufferSize
+ ss := so.getReceiveBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Clamp the receive buffer size to the stack-defined maximum.
+ if v > max {
+ v = max
+ }
+
+ // Scale the size by the packet overhead factor, saturating at MaxInt32.
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
+ }
+ } else {
+ v = math.MaxInt32
+ }
+
+ oldSz := atomic.LoadInt64(&so.receiveBufferSize)
+ // Notify endpoint about change in buffer size.
+ newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
+ atomic.StoreInt64(&so.receiveBufferSize, newSz)
+}
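The clamping order in SetReceiveBufferSize matters: the request is first capped at the stack maximum, then doubled for packet overhead (so the effective value can exceed the advertised max, matching Linux's SO_RCVBUF doubling), and only then raised to the minimum. A standalone sketch with made-up limits (illustrative, not the stack's actual defaults):

package main

import (
	"fmt"
	"math"
)

const packetOverheadFactor = 2

// clampRcvBuf mirrors the ordering in SetReceiveBufferSize: cap at max,
// double for overhead, then raise to min (or saturate at MaxInt32).
func clampRcvBuf(v, min, max int64) int64 {
	if v > max {
		v = max
	}
	if v < math.MaxInt32/packetOverheadFactor {
		v *= packetOverheadFactor
		if v < min {
			v = min
		}
	} else {
		v = math.MaxInt32
	}
	return v
}

func main() {
	// Hypothetical limits: min=4KiB, max=4MiB.
	fmt.Println(clampRcvBuf(1024, 4096, 4<<20))  // 4096: doubled, then raised to min
	fmt.Println(clampRcvBuf(8<<20, 4096, 4<<20)) // 8388608: capped at max, then doubled
	fmt.Println(clampRcvBuf(math.MaxInt32, 4096, math.MaxInt32)) // saturates at MaxInt32
}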
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 49362333a..2bd6a67f5 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -45,6 +45,7 @@ go_library(
"addressable_endpoint_state.go",
"conntrack.go",
"headertype_string.go",
+ "hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",
@@ -66,6 +67,7 @@ go_library(
"stack.go",
"stack_global_state.go",
"stack_options.go",
+ "tcp.go",
"transport_demuxer.go",
"tuple_list.go",
],
@@ -115,6 +117,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -139,6 +142,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 3f083928f..5720e7543 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
+ "fmt"
"sync"
"time"
@@ -29,7 +30,7 @@ import (
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
-// looking at the tuples in the Prerouting and Output hooks.
+// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.
@@ -46,12 +47,14 @@ const (
)
// Manipulation type for the connection.
+// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
+// DNAT at the same time.
type manipType int
const (
manipNone manipType = iota
- manipDstPrerouting
- manipDstOutput
+ manipSource
+ manipDestination
)
// tuple holds a connection's identifying and manipulating data in one
@@ -108,6 +111,7 @@ type conn struct {
reply tuple
// manip indicates if the packet should be manipulated. It is immutable.
+ // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
manip manipType
// tcbHook indicates if the packet is inbound or outbound to
@@ -124,6 +128,18 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
+// newConn creates a new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
}, nil
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
func (ct *ConnTrack) init() {
ct.mu.Lock()
defer ct.mu.Unlock()
@@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
return nil
}
- // Create a new connection and change the port as per the iptables
- // rule. This tuple will be used to manipulate the packet in
- // handlePacket.
replyTID := tid.reply()
replyTID.srcAddr = address
replyTID.srcPort = port
- var manip manipType
- switch hook {
- case Prerouting:
- manip = manipDstPrerouting
- case Output:
- manip = manipDstOutput
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
}
- conn := newConn(tid, replyTID, manip, hook)
+ conn = newConn(tid, replyTID, manipDestination, hook)
+ ct.insertConn(conn)
+ return conn
+}
+
+func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Input && hook != Postrouting {
+ return nil
+ }
+
+ replyTID := tid.reply()
+ replyTID.dstAddr = address
+ replyTID.dstPort = port
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
+ }
+ conn = newConn(tid, replyTID, manipSource, hook)
ct.insertConn(conn)
return conn
}
@@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Now that we hold the locks, ensure the tuple hasn't been inserted by
// another thread.
+ // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
alreadyInserted := false
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
if other.tupleID == conn.original.tupleID {
@@ -343,95 +369,17 @@ func (ct *ConnTrack) insertConn(conn *conn) {
}
}
-// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
-func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For prerouting redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified.
- switch dir {
- case dirOriginal:
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- case dirReply:
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
- // on inbound packets, so we don't recalculate them. However, we should
- // support cases when they are validated, e.g. when we can't offload
- // receive checksumming.
-
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
-// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For output redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified. For prerouting redirection, we only reach this point
- // when replying, so packet sources are modified.
- if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- } else {
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- if gso != nil && gso.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
- } else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
- }
-
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
// handlePacket will manipulate the port and address of the packet if the
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if pkt.NatDone {
return false
}
- if hook != Prerouting && hook != Output {
+ switch hook {
+ case Prerouting, Input, Output, Postrouting:
+ default:
return false
}
@@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
conn, dir := ct.connFor(pkt)
- // Connection or Rule not found for the packet.
+ // Connection not found for the packet.
if conn == nil {
- return true
+ // If this is the last hook in the data path for this packet (Input if
+ // incoming, Postrouting if outgoing), indicate that a connection should be
+ // inserted by the end of this hook.
+ return hook == Input || hook == Postrouting
}
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return false
}
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
+
+ switch hook {
+ case Prerouting, Output:
+ if conn.manip == manipDestination {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetDestinationPort(conn.reply.srcPort)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ tcpHeader.SetSourcePort(conn.original.dstPort)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+ pkt.NatDone = true
+ }
+ case Input, Postrouting:
+ if conn.manip == manipSource {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetSourcePort(conn.reply.dstPort)
+ netHeader.SetSourceAddress(conn.reply.dstAddr)
+ case dirReply:
+ tcpHeader.SetDestinationPort(conn.original.srcPort)
+ netHeader.SetDestinationAddress(conn.original.srcAddr)
+ }
+ pkt.NatDone = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+ if !pkt.NatDone {
+ return false
+ }
+
switch hook {
- case Prerouting:
- handlePacketPrerouting(pkt, conn, dir)
- case Output:
- handlePacketOutput(pkt, conn, gso, r, dir)
+ case Prerouting, Input:
+ case Output, Postrouting:
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
- pkt.NatDone = true
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
@@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
if conn == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip == manipNone {
- // Unmanipulated connection.
+ } else if conn.manip != manipDestination {
+ // Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
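After this refactor, conntrack applies destination NAT on Prerouting/Output and source NAT on Input/Postrouting, inverting the rewrite for reply-direction packets. A compact sketch of just that mapping, mirroring handlePacket's switch (types simplified; illustrative only):

package main

import "fmt"

type manipType int

const (
	manipSource manipType = iota
	manipDestination
)

type rewrite struct{ field, from string }

// natRewrite shows which header field is rewritten, and from which stored
// tuple, for a given manipulation type and packet direction.
func natRewrite(manip manipType, replyDir bool) rewrite {
	switch {
	case manip == manipDestination && !replyDir:
		return rewrite{"dst", "reply.src"} // original direction: rewrite destination
	case manip == manipDestination && replyDir:
		return rewrite{"src", "original.dst"} // reply: undo it on the source
	case manip == manipSource && !replyDir:
		return rewrite{"src", "reply.dst"}
	default:
		return rewrite{"dst", "original.src"}
	}
}

func main() {
	fmt.Println(natRewrite(manipDestination, false)) // {dst reply.src}
	fmt.Println(natRewrite(manipSource, true))       // {dst original.src}
}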
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 16ee75bc4..7d3725681 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: vv.ToView().ToVectorisedView(),
})
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
_ = r.WriteHeaderIncludedPacket(pkt)
}
@@ -117,7 +117,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
@@ -125,11 +125,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress()[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -139,7 +139,7 @@ func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *Packet
return &tcpip.ErrMalformedHeader{}
}
- return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -264,6 +264,8 @@ type fwdTestPacketInfo struct {
Pkt *PacketBuffer
}
+var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)
+
type fwdTestLinkEndpoint struct {
dispatcher NetworkDispatcher
mtu uint32
@@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
return caps | CapabilityResolutionRequired
}
-// GSOMaxSize returns the maximum GSO packet size.
-func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
- return 1 << 15
-}
-
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
@@ -322,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -338,10 +335,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, gso, protocol, pkt)
+ e.WritePacket(r, protocol, pkt)
n++
}
diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go
new file mode 100644
index 000000000..3dc8a7b02
--- /dev/null
+++ b/pkg/tcpip/stack/hook_string.go
@@ -0,0 +1,41 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type Hook ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Prerouting-0]
+ _ = x[Input-1]
+ _ = x[Forward-2]
+ _ = x[Output-3]
+ _ = x[Postrouting-4]
+ _ = x[NumHooks-5]
+}
+
+const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
+
+var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
+
+func (i Hook) String() string {
+ if i >= Hook(len(_Hook_index)-1) {
+ return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
+}
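The generated stringer above avoids per-value string allocations by concatenating all names into one constant string and slicing it with an offset table: name i lives at _Hook_name[_Hook_index[i]:_Hook_index[i+1]]. The same trick for a toy two-value enum (illustrative):

package main

import (
	"fmt"
	"strconv"
)

type color int

const (
	red color = iota
	green
)

// One concatenated name string plus offsets, the same layout "stringer"
// emits for Hook above.
const _name = "redgreen"

var _index = [...]uint8{0, 3, 8}

func (c color) String() string {
	if c < 0 || int(c) >= len(_index)-1 {
		return "color(" + strconv.FormatInt(int64(c), 10) + ")"
	}
	return _name[_index[c]:_index[c+1]]
}

func main() {
	fmt.Println(red, green, color(7)) // red green color(7)
}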
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 52890f6eb..e2894c548 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: {MangleID, NATID},
- Input: {NATID, FilterID},
- Output: {MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
+ Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -266,12 +267,12 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
// which address can be gathered. Currently, address is only needed for
// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -285,7 +286,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
// Packets are manipulated only if connection and matching
// NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -302,7 +303,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -313,7 +314,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -385,10 +386,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
+ if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +409,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +430,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,7 +456,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -478,7 +479,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
+ return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
}
// OriginalDst returns the original destination of redirected connections. It
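With Postrouting wired into DefaultTables' priorities above, packets leaving the stack now traverse the mangle table and then nat on that hook, in the order listed. A tiny sketch of how such a priority table drives per-hook traversal (IDs and names simplified from the diff):

package main

import "fmt"

type hook int

const (
	prerouting hook = iota
	input
	forward
	output
	postrouting
	numHooks
)

type tableID int

const (
	mangleID tableID = iota
	natID
	filterID
)

func main() {
	// Per-hook table traversal order after this change; Postrouting is new.
	priorities := [numHooks][]tableID{
		prerouting:  {mangleID, natID},
		input:       {natID, filterID},
		output:      {mangleID, natID, filterID},
		postrouting: {mangleID, natID},
	}
	for _, t := range priorities[postrouting] {
		fmt.Println(t) // 0 (mangle), then 1 (nat)
	}
}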
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 0e8b90c9b..2812c89aa 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -103,7 +103,7 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -174,7 +174,85 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, gso, r)
+ ct.handlePacket(pkt, hook, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
+
+// SNATTarget modifies the source port/IP in the outgoing packets.
+type SNATTarget struct {
+ Addr tcpip.Address
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ switch hook {
+ case Postrouting, Input:
+ case Prerouting, Output, Forward:
+ panic(fmt.Sprintf("%s not supported", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetChecksum(0)
+ udpHeader.SetSourcePort(st.Port)
+ netHeader := pkt.Network()
+ netHeader.SetSourceAddress(st.Addr)
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.RequiresTXTransportChecksum() {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up connection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
default:
return RuleDrop, 0
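insertSNATConn, used by the TCP branch above, stores the NATed source in the reply tuple's destination, which is how inbound replies find the connection and get un-NATed. A self-contained sketch of that tuple bookkeeping (addresses and the tupleID type are illustrative):

package main

import "fmt"

type tupleID struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

// reply inverts a tuple, as tupleID.reply does in conntrack.
func (t tupleID) reply() tupleID {
	return tupleID{srcAddr: t.dstAddr, dstAddr: t.srcAddr, srcPort: t.dstPort, dstPort: t.srcPort}
}

func main() {
	// Outgoing packet 10.0.0.2:7000 -> 8.8.8.8:53 hitting an SNAT rule that
	// rewrites the source to 10.0.0.1:4096 (all values made up).
	tid := tupleID{"10.0.0.2", "8.8.8.8", 7000, 53}

	// As in insertSNATConn: the reply tuple's *destination* carries the
	// NATed source, so inbound replies to 10.0.0.1:4096 match this
	// connection and are rewritten back.
	replyTID := tid.reply()
	replyTID.dstAddr = "10.0.0.1"
	replyTID.dstPort = 4096

	fmt.Printf("original: %+v\nreply:    %+v\n", tid, replyTID)
}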
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index b0d84befb..4631ab93f 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -345,5 +345,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 14124ae66..c585b81b2 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -33,15 +33,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
+var (
+ addr1 = testutil.MustParse6("a00::1")
+ addr2 = testutil.MustParse6("a00::2")
+ addr3 = testutil.MustParse6("a00::3")
+)
+
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
@@ -1142,57 +1146,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
-// TestNoRouterDiscovery tests that router discovery will not be performed if
-// configured not to.
-func TestNoRouterDiscovery(t *testing.T) {
- // Being configured to discover routers means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // router discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+func TestDynamicConfigurationsDisabled(t *testing.T) {
+ const (
+ nicID = 1
+ maxRtrSolicitDelay = time.Second
+ )
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ prefix := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse6("102:304:506:708::"),
+ PrefixLen: 64,
+ }
- // Rx an RA with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router when configured not to")
- default:
+ tests := []struct {
+ name string
+ config func(bool) ipv6.NDPConfigurations
+ ra *stack.PacketBuffer
+ }{
+ {
+ name: "No Router Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
+ },
+ ra: raBuf(llAddr2, 1000),
+ },
+ {
+ name: "No Prefix Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0),
+ },
+ {
+ name: "No Autogenerate Addresses",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Being configured to discover routers/prefixes or auto-generate
+ // addresses means RAs must be handled, and router/prefix discovery or
+ // SLAAC must be enabled.
+ //
+ // This tests all possible combinations of the configurations where
+ // router/prefix discovery or SLAAC is disabled.
+ for i := 0; i < 7; i++ {
+ handle := ipv6.HandlingRAsDisabled
+ if i&1 != 0 {
+ handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled
+ }
+ enable := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ ndpConfigs := test.config(enable)
+ ndpConfigs.HandleRAs = handle
+ ndpConfigs.MaxRtrSolicitations = 1
+ ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay
+ ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
+ })
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding
+ ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err)
+ }
+ stats := ep.Stats()
+ v6Stats, ok := stats.(*ipv6.Stats)
+ if !ok {
+ t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats)
+ }
+
+ // Make sure that when RA handling is enabled, we solicit routers.
+ clock.Advance(maxRtrSolicitDelay)
+ if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want)
+ }
+ if handleRAsDisabled {
+ if p, ok := e.Read(); ok {
+ t.Errorf("unexpectedly got a packet = %#v", p)
+ }
+ } else if p, ok := e.Read(); !ok {
+ t.Error("expected router solicitation packet")
+ } else if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ } else {
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(nil)),
+ )
+ }
+
+ // Make sure we do not discover any routers or prefixes, or perform
+ // SLAAC on reception of an RA.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone())
+ // Make sure that the unhandled RA stat is only incremented when
+ // handling RAs is disabled.
+ if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e)
+ default:
+ }
+ })
}
})
}
}
+func boolToUint64(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
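
The loop above enumerates the seven "feature disabled" configurations by decoding the low bits of i; i = 7 would decode to the one combination (RAs handled, feature enabled, forwarding disabled) that actually performs discovery or SLAAC, and that case is covered by other tests. boolToUint64 then turns each expected outcome into a counter value for the stat comparisons. A standalone sketch of the same decoding:

package example

import "fmt"

func main() {
	for i := 0; i < 7; i++ {
		handle := i&1 != 0     // bit 0: RAs are handled
		enable := i&2 != 0     // bit 1: the feature under test is enabled
		forwarding := i&4 == 0 // bit 2 (inverted): forwarding is disabled
		fmt.Println(handle, enable, forwarding)
	}
}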
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// discovered flag set to discovered.
func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
}
+func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
+ tests := [...]struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ forwarding bool
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding disabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding enabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f(t, test.handleRAs, test.forwarding)
+ })
+ }
+}
+
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
@@ -1203,7 +1348,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1237,103 +1382,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
}
- default:
- t.Fatal("expected router discovery event")
}
- }
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for router discovery event")
}
- }
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
-
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
+ // Rx an RA from lladdr2 with lesser lifetime.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ })
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1347,7 +1498,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1386,57 +1537,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
}
-// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
-// configured not to.
-func TestNoPrefixDiscovery(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
-
- // Being configured to discover prefixes means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // prefix discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
-
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for prefix on nic with ID 1, and the
// discovered flag set to discovered.
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
@@ -1455,8 +1555,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1494,87 +1593,93 @@ func TestPrefixDiscovery(t *testing.T) {
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
}
- default:
- t.Fatal("expected prefix discovery event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- default:
- }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- expectPrefixEvent(subnet1, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- expectPrefixEvent(subnet2, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
- // Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- expectPrefixEvent(subnet3, true)
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
- // Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- expectPrefixEvent(subnet1, false)
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
- // Receive an RA with prefix2 in a PI with lesser lifetime.
- lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly received prefix event when updating lifetime")
- default:
- }
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
- // Wait for prefix2's most recent invalidation job plus some buffer to
- // expire.
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for prefix discovery event")
- }
- // Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- expectPrefixEvent(subnet3, false)
+ // Wait for prefix2's most recent invalidation job plus some buffer to
+ // expire.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+ })
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
@@ -1590,7 +1695,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}()
prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
}
subnet := prefix.Subnet()
@@ -1603,7 +1708,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1688,7 +1793,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: false,
DiscoverOnLinkPrefixes: true,
},
@@ -1753,53 +1858,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
return containsAddr(list, protocolAddress)
}
-// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
-func TestNoAutoGenAddr(t *testing.T) {
- prefix, _, _ := prefixSubnetAddr(0, "")
-
- // Being configured to auto-generate addresses means handle and
- // autogen are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, autogen =
- // true and forwarding = false (the required configuration to do
- // SLAAC) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- autogen := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
-
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for addr on nic with ID 1, and the
// event type is set to eventType.
func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
@@ -1808,7 +1866,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr2(t *testing.T) {
+func TestAutoGenAddr(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
@@ -1820,96 +1878,102 @@ func TestAutoGenAddr2(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
- default:
- t.Fatal("expected addr auto gen event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- default:
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- default:
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
- // Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- default:
- }
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
- // Wait for addr of prefix1 to be invalidated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+ })
}
func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
@@ -1997,7 +2061,7 @@ func TestAutoGenTempAddr(t *testing.T) {
RetransmitTimer: test.retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2298,7 +2362,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2385,7 +2449,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2534,7 +2598,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2735,7 +2799,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
Clock: clock,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
@@ -2880,7 +2944,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: ndpDisp,
@@ -3347,7 +3411,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3490,7 +3554,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3557,7 +3621,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3723,7 +3787,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3805,7 +3869,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3969,7 +4033,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
@@ -3996,7 +4060,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Temporary address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -4146,7 +4210,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4274,7 +4338,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4480,7 +4544,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4531,7 +4595,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4625,8 +4689,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
}
-// TestCleanupNDPState tests that all discovered routers and prefixes, and
-// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
+ const (
+ lifetimeSeconds = 999
+ nicID = 1
+ )
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenLinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
+ if err := s.CreateNIC(nicID, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID)
+ }
+
+ prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1)
+ e1.InjectInbound(
+ header.IPv6ProtocolNumber,
+ raBufWithPI(
+ llAddr3,
+ lifetimeSeconds,
+ prefix,
+ true, /* onLink */
+ true, /* auto */
+ lifetimeSeconds,
+ lifetimeSeconds,
+ ),
+ )
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID)
+ }
+
+ // Enabling or disabling forwarding should not invalidate discovered prefixes
+ // or routers, or auto-generated addresses.
+ for _, forwarding := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) {
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpected router event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpected prefix event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpected auto-gen addr event = %#v", e)
+ default:
+ }
+ })
+ }
+}
+
func TestCleanupNDPState(t *testing.T) {
const (
lifetimeSeconds = 5
@@ -4655,18 +4821,6 @@ func TestCleanupNDPState(t *testing.T) {
maxAutoGenAddrEvents int
skipFinalAddrCheck bool
}{
- // A NIC should still keep its auto-generated link-local address when
- // becoming a router.
- {
- name: "Enable forwarding",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(ipv6.ProtocolNumber, true)
- },
- keepAutoGenLinkLocal: true,
- maxAutoGenAddrEvents: 4,
- },
-
// A NIC should cleanup all NDP state when it is disabled.
{
name: "Disable NIC",
@@ -4718,7 +4872,7 @@ func TestCleanupNDPState(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
DiscoverOnLinkPrefixes: true,
AutoGenGlobalAddresses: true,
@@ -4991,7 +5145,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -5182,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) {
},
}
+ subTests := []struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ afterFirstRS func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ afterFirstRS: func(*testing.T, *stack.Stack) {},
+ },
+
+ // Enabling forwarding when RAs are always configured to be handled
+ // should not stop router solicitations.
+ {
+ name: "Handle RAs always",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ afterFirstRS: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ },
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
- }
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
+ }
- clock.Advance(timeout)
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected router solicitation packet")
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: subTest.handleRAs,
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- clock.Advance(timeout)
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpectedly got a packet = %#v", p)
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay)
- remaining--
- }
+ subTest.afterFirstRS(t, s)
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
- waitForPkt(time.Nanosecond)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt)
- }
- }
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt)
+ }
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt)
- } else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay)
- }
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
- if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
}
})
}
@@ -5362,13 +5547,14 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
MaxRtrSolicitations: maxRtrSolicitations,
RtrSolicitationInterval: interval,
MaxRtrSolicitationDelay: delay,
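
The change repeated throughout this file replaces the old boolean HandleRAs with the tri-state ipv6.HandleRAsConfiguration (HandlingRAsDisabled, HandlingRAsEnabledWhenForwardingDisabled, HandlingRAsAlwaysEnabled), so that RA handling can be tied to the forwarding state. A sketch of how such a tri-state gates RA handling — the enabled method is an illustrative assumption, not the ipv6 package's actual API:

package example

type handleRAsConfiguration int

const (
	handlingRAsDisabled handleRAsConfiguration = iota
	handlingRAsEnabledWhenForwardingDisabled
	handlingRAsAlwaysEnabled
)

// enabled reports whether RAs should be handled for a NIC in the given
// forwarding state; it mirrors the handleRAsDisabled expression used in
// TestDynamicConfigurationsDisabled above.
func (c handleRAsConfiguration) enabled(forwarding bool) bool {
	switch c {
	case handlingRAsEnabledWhenForwardingDisabled:
		return !forwarding
	case handlingRAsAlwaysEnabled:
		return true
	default: // handlingRAsDisabled
		return false
	}
}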
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 48bb75e2f..9821a18d3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
func BenchmarkCacheClear(b *testing.B) {
b.StopTimer()
config := DefaultNUDConfigurations()
- clock := &tcpip.StdClock{}
+ clock := tcpip.NewStdClock()
linkRes := newTestNeighborResolver(nil, config, clock)
linkRes.delay = 0
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index bb2b2d705..1d39ee73d 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -26,14 +26,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
entryTestNICID tcpip.NICID = 1
- entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
@@ -44,6 +43,11 @@ const (
entryTestNetDefaultMTU = 65536
)
+var (
+ entryTestAddr1 = testutil.MustParse6("a::1")
+ entryTestAddr2 = testutil.MustParse6("a::2")
+)
+
// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
// time.
func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ca15c0691..8d615500f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -316,30 +316,30 @@ func (n *nic) IsLoopback() bool {
}
// WritePacket implements NetworkLinkEndpoint.
-func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
- _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+func (n *nic) WritePacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, protocol, pkt)
return err
}
-func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) writePacketBuffer(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
switch pkt := pkt.(type) {
case *PacketBuffer:
- if err := n.writePacket(r, gso, protocol, pkt); err != nil {
+ if err := n.writePacket(r, protocol, pkt); err != nil {
return 0, err
}
return 1, nil
case *PacketBufferList:
- return n.writePackets(r, gso, protocol, *pkt)
+ return n.writePackets(r, protocol, *pkt)
default:
panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
}
-func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) enqueuePacketBuffer(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
routeInfo, _, err := r.resolvedFields(nil)
switch err.(type) {
case nil:
- return n.writePacketBuffer(routeInfo, gso, protocol, pkt)
+ return n.writePacketBuffer(routeInfo, protocol, pkt)
case *tcpip.ErrWouldBlock:
// As per relevant RFCs, we should queue packets while we wait for link
// resolution to complete.
@@ -358,28 +358,27 @@ func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt
// SHOULD be limited to some small value. When a queue overflows, the new
// arrival SHOULD replace the oldest entry. Once address resolution
// completes, the node transmits any queued packets.
- return n.linkResQueue.enqueue(r, gso, protocol, pkt)
+ return n.linkResQueue.enqueue(r, protocol, pkt)
default:
return 0, err
}
}
// WritePacketToRemote implements NetworkInterface.
-func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
var r RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
- return n.writePacket(r, gso, protocol, pkt)
+ return n.writePacket(r, protocol, pkt)
}
-func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
pkt.EgressRoute = r
- pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
- if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -389,18 +388,17 @@ func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
}
// WritePackets implements NetworkLinkEndpoint.
-func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
+func (n *nic) WritePackets(r *Route, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return n.enqueuePacketBuffer(r, protocol, &pkts)
}
-func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
+func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
- pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
}
- writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index c0f956e53..8a3005295 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -65,12 +65,12 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
}
// WritePacket implements NetworkEndpoint.WritePacket.
-func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
+func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
return nil
}
// WritePackets implements NetworkEndpoint.WritePackets.
-func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
+func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
// Our tests don't use this so we don't support it.
return 0, &tcpip.ErrNotSupported{}
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 8f288675d..9527416cf 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -103,7 +103,7 @@ type PacketBuffer struct {
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
EgressRoute RouteInfo
- GSOOptions *GSO
+ GSOOptions GSO
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
@@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
- return NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
})
+	// TODO(gvisor.dev/issue/5696): reimplement conntrack so that this flag no
+	// longer needs to be maintained in the packet. Currently conntrack needs it
+	// to tell whether a noop connection should be inserted at the Input hook.
+	// Once conntrack redefines the manipulation field as mutable, the special
+	// noop connection won't be needed.
+ if pk.NatDone {
+ newPk.NatDone = true
+ }
+ return newPk
}
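A minimal sketch of the flag copy above, with a stand-in packet type: the clone gets fresh backing data, but NatDone is carried forward so conntrack does not insert a noop connection for an already-translated packet.

package main

import "fmt"

// packet is a stand-in for PacketBuffer with just the relevant field.
type packet struct {
	data    []byte
	NatDone bool
}

// cloneToInbound mirrors the change above: fresh backing data, but the
// NatDone flag is carried over.
func cloneToInbound(pk *packet) *packet {
	clone := &packet{data: append([]byte(nil), pk.data...)}
	if pk.NatDone {
		clone.NatDone = true
	}
	return clone
}

func main() {
	orig := &packet{data: []byte{1, 2, 3}, NatDone: true}
	fmt.Println(cloneToInbound(orig).NatDone) // true
}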
// headerInfo stores metadata about a header in a packet.
@@ -355,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) {
return d.pk.data.PullUp(size)
}
-// TrimFront removes count from the beginning of d. It panics if count >
-// d.Size().
-func (d PacketData) TrimFront(count int) {
+// DeleteFront removes count from the beginning of d. It panics if count >
+// d.Size(). All backing storage references after the front of d are
+// invalidated.
+func (d PacketData) DeleteFront(count int) {
d.pk.data.TrimFront(count)
}
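Since DeleteFront invalidates references into the packet's backing storage, callers that still need the leading bytes must copy them out first — the same pattern the stack_test.go hunk below adopts. A tiny sketch with a plain byte slice standing in for PacketData:

package main

import "fmt"

func main() {
	buf := []byte("headerpayload")

	// Copy the leading bytes out first; views into buf are no longer
	// valid once the front is deleted.
	hdr := append([]byte(nil), buf[:6]...)
	buf = buf[6:] // stands in for pkt.Data().DeleteFront(6)

	fmt.Printf("hdr=%q rest=%q\n", hdr, buf) // hdr="header" rest="payload"
}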
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 6728370c3..bd4eb4fed 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, test.data)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
- concatViews(test.link, test.network, test.transport, test.data))
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(test.link, test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(test.transport, test.data))
+ // Check the after state.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.link,
+ network: test.network,
+ transport: test.transport,
+ data: test.data,
+ })
})
}
}
@@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) {
if got, want := pk.Size(), len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- // After state of pk.
- var (
- link = test.data[:test.link]
- network = test.data[test.link:][:test.network]
- transport = test.data[test.link+test.network:][:test.transport]
- payload = test.data[allHdrSize:]
- )
- checkData(t, pk, payload)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(link, network, transport, payload))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(network, transport, payload))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(transport, payload))
+ // Check the after state of pk.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.data[:test.link],
+ network: test.data[test.link:][:test.network],
+ transport: test.data[test.link+test.network:][:test.transport],
+ data: test.data[allHdrSize:],
+ })
})
}
}
@@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
})
}
+// This covers an obscure use case seen in code that verifies packets before
+// sending them out: the headers are parsed back out of the buffer to verify
+// them. PacketHeader was not initially designed to mix Push() and Consume(),
+// but the combination works and has been relied upon, so test it here.
+func TestPacketHeaderPushConsumeMixed(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := append([]byte(nil), network...)
+ initData = append(initData, data...)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Consume network header
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(network))
+ if !ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network))
+ }
+ checkViewEqual(t, "gotNetwork", gotNetwork, network)
+
+ // 2. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ network: network,
+ data: data,
+ })
+}
+
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
const headerSize = 10
@@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // TrimFront
+ // DeleteFront
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().TrimFront(n)
+ pkt.Data().DeleteFront(n)
checkData(t, pkt, []byte(tc.data)[n:])
})
@@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) {
}
}
+type packetContents struct {
+ link buffer.View
+ network buffer.View
+ transport buffer.View
+ data buffer.View
+}
+
+func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) {
+ t.Helper()
+ // Headers.
+ checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link)
+ checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network)
+ checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport)
+ // Data.
+ checkData(t, pk, want.data)
+ // Whole packet.
+ checkViewEqual(t, prefix+"pk.Views()",
+ concatViews(pk.Views()...),
+ concatViews(want.link, want.network, want.transport, want.data))
+ // PayloadSince.
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(want.link, want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(want.transport, want.data))
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, data)
- checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
- // Check the initial values for each header.
- checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
- checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
- checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
- // Check the initial valies for PayloadSince.
- checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()), data)
+ checkPacketContents(t, "Initial ", pk, packetContents{
+ data: data,
+ })
}
func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index e936aa728..13e8907ec 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -46,7 +46,6 @@ func (p *PacketBufferList) len() int {
type pendingPacket struct {
routeInfo RouteInfo
- gso *GSO
proto tcpip.NetworkProtocolNumber
pkt pendingPacketBuffer
}
@@ -119,7 +118,7 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi
// If the maximum number of pending resolutions is reached, the packets
// associated with the oldest link resolution will be dequeued as if they failed
// link resolution.
-func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
f.mu.Lock()
// Make sure we attempt resolution while holding f's lock so that we avoid
// a race where link resolution completes before we enqueue the packets.
@@ -137,7 +136,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N
// The route resolved immediately, so we don't need to wait for link
// resolution to send the packet.
f.mu.Unlock()
- return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt)
+ return f.nic.writePacketBuffer(routeInfo, proto, pkt)
case *tcpip.ErrWouldBlock:
// We need to wait for link resolution to complete.
default:
@@ -150,7 +149,6 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N
packets, ok := f.mu.packets[ch]
packets = append(packets, pendingPacket{
routeInfo: routeInfo,
- gso: gso,
proto: proto,
pkt: pkt,
})
@@ -211,7 +209,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l
for _, p := range packets {
if err == nil {
p.routeInfo.RemoteLinkAddress = linkAddr
- _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ _, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt)
} else {
f.incrementOutgoingPacketErrors(p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ff3a385e1..e26225552 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -537,14 +537,14 @@ type NetworkInterface interface {
CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
// WritePacketToRemote writes the packet to the given remote link address.
- WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePacket writes a packet with the given protocol through the given
// route.
//
// WritePacket takes ownership of the packet buffer. The packet buffer's
// network and transport header must be set.
- WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol through the given
// route. Must not be called with an empty list of packet buffers.
@@ -554,7 +554,7 @@ type NetworkInterface interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it is ever used for something else, syscall filters
// may need to be updated.
- WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+ WritePackets(*Route, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
// HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP
// request or NDP Neighbor Solicitation).
@@ -610,12 +610,12 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It takes ownership of pkt. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
+ WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length. It takes ownership of pkts and
// underlying packets.
- WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
+ WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address. It takes ownership of pkt.
@@ -756,11 +756,6 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityHardwareGSO
-
- // CapabilitySoftwareGSO indicates the link endpoint supports of sending
- // multiple packets using a single call (LinkEndpoint.WritePackets).
- CapabilitySoftwareGSO
)
// NetworkLinkEndpoint is a data-link layer that supports sending network
@@ -832,7 +827,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacket(RouteInfo, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol and route. Must not be
// called with an empty list of packet buffers.
@@ -842,7 +837,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it is ever used for something else, syscall filters
// may need to be updated.
- WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+ WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -1047,10 +1042,29 @@ type GSO struct {
MaxSize uint32
}
+// SupportedGSO is the type of segmentation offloading supported.
+type SupportedGSO int
+
+const (
+ // GSONotSupported indicates that segmentation offloading is not supported.
+ GSONotSupported SupportedGSO = iota
+
+ // HWGSOSupported indicates that segmentation offloading may be performed by
+ // the hardware.
+ HWGSOSupported
+
+ // SWGSOSupported indicates that segmentation offloading may be performed in
+ // software.
+ SWGSOSupported
+)
+
// GSOEndpoint provides access to GSO properties.
type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
+
+ // SupportedGSO returns the supported segmentation offloading.
+ SupportedGSO() SupportedGSO
}
// SoftwareGSOMaxSize is the maximum allowed size of a software GSO segment.
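With the capability bits gone, a link endpoint now advertises GSO support by implementing GSOEndpoint. A sketch of a hypothetical software-GSO endpoint; the local type copies keep it self-contained and are not the real stack definitions:

package main

import "fmt"

// Local copies of the new API surface, to keep the sketch self-contained.
type SupportedGSO int

const (
	GSONotSupported SupportedGSO = iota
	HWGSOSupported
	SWGSOSupported
)

type GSOEndpoint interface {
	GSOMaxSize() uint32
	SupportedGSO() SupportedGSO
}

// swGSOEndpoint is a hypothetical link endpoint that segments in software.
type swGSOEndpoint struct{}

func (swGSOEndpoint) GSOMaxSize() uint32         { return 1 << 16 }
func (swGSOEndpoint) SupportedGSO() SupportedGSO { return SWGSOSupported }

func main() {
	var ep interface{} = swGSOEndpoint{}
	// Callers probe with a type assertion, as Route does in route.go below.
	if gso, ok := ep.(GSOEndpoint); ok {
		fmt.Println(gso.SupportedGSO() == SWGSOSupported) // true
	}
}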
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 39344808d..8a044c073 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
localAddr = addressEndpoint.AddressWithPrefix().Address
}
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) {
addressEndpoint.DecRef()
return nil
}
@@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool {
// HasSoftwareGSOCapability returns true if the route supports software GSO.
func (r *Route) HasSoftwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == SWGSOSupported
+ }
+ return false
}
// HasHardwareGSOCapability returns true if the route supports hardware GSO.
func (r *Route) HasHardwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == HWGSOSupported
+ }
+ return false
}
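A sketch of how a sender might branch on these predicates — hardware GSO hands one oversized buffer to the NIC, software GSO batches through WritePackets. The route type here is a stand-in exposing only the two methods above:

package main

import "fmt"

type SupportedGSOKind int

const (
	none SupportedGSOKind = iota
	hw
	sw
)

// route is a stand-in exposing only the two predicates above.
type route struct{ gso SupportedGSOKind }

func (r route) HasHardwareGSOCapability() bool { return r.gso == hw }
func (r route) HasSoftwareGSOCapability() bool { return r.gso == sw }

func main() {
	for _, r := range []route{{none}, {hw}, {sw}} {
		switch {
		case r.HasHardwareGSOCapability():
			fmt.Println("hand one oversized buffer to the NIC")
		case r.HasSoftwareGSOCapability():
			fmt.Println("batch segments and flush via WritePackets")
		default:
			fmt.Println("write one MSS-sized packet at a time")
		}
	}
}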
// HasSaveRestoreCapability returns true if the route supports save/restore.
@@ -448,22 +454,22 @@ func (r *Route) isValidForOutgoingRLocked() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
+func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
return &tcpip.ErrInvalidEndpointState{}
}
- return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
-func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (r *Route) WritePackets(pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
if !r.isValidForOutgoing() {
return 0, &tcpip.ErrInvalidEndpointState{}
}
- return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 931a97ddc..436392f23 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -35,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -56,306 +55,6 @@ type transportProtocolState struct {
defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
-// TCPProbeFunc is the expected function type for a TCP probe function to be
-// passed to stack.AddTCPProbe.
-type TCPProbeFunc func(s TCPEndpointState)
-
-// TCPCubicState is used to hold a copy of the internal cubic state when the
-// TCPProbeFunc is invoked.
-type TCPCubicState struct {
- WLastMax float64
- WMax float64
- T time.Time
- TimeSinceLastCongestion time.Duration
- C float64
- K float64
- Beta float64
- WC float64
- WEst float64
-}
-
-// TCPRACKState is used to hold a copy of the internal RACK state when the
-// TCPProbeFunc is invoked.
-type TCPRACKState struct {
- XmitTime time.Time
- EndSequence seqnum.Value
- FACK seqnum.Value
- RTT time.Duration
- Reord bool
- DSACKSeen bool
- ReoWnd time.Duration
- ReoWndIncr uint8
- ReoWndPersist int8
- RTTSeq seqnum.Value
-}
-
-// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
-type TCPEndpointID struct {
- // LocalPort is the local port associated with the endpoint.
- LocalPort uint16
-
- // LocalAddress is the local [network layer] address associated with
- // the endpoint.
- LocalAddress tcpip.Address
-
- // RemotePort is the remote port associated with the endpoint.
- RemotePort uint16
-
- // RemoteAddress it the remote [network layer] address associated with
- // the endpoint.
- RemoteAddress tcpip.Address
-}
-
-// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
-// TCP endpoint.
-type TCPFastRecoveryState struct {
- // Active if true indicates the endpoint is in fast recovery.
- Active bool
-
- // First is the first unacknowledged sequence number being recovered.
- First seqnum.Value
-
- // Last is the 'recover' sequence number that indicates the point at
- // which we should exit recovery barring any timeouts etc.
- Last seqnum.Value
-
- // MaxCwnd is the maximum value we are permitted to grow the congestion
- // window during recovery. This is set at the time we enter recovery.
- MaxCwnd int
-
- // HighRxt is the highest sequence number which has been retransmitted
- // during the current loss recovery phase.
- // See: RFC 6675 Section 2 for details.
- HighRxt seqnum.Value
-
- // RescueRxt is the highest sequence number which has been
- // optimistically retransmitted to prevent stalling of the ACK clock
- // when there is loss at the end of the window and no new data is
- // available for transmission.
- // See: RFC 6675 Section 2 for details.
- RescueRxt seqnum.Value
-}
-
-// TCPReceiverState holds a copy of the internal state of the receiver for
-// a given TCP endpoint.
-type TCPReceiverState struct {
- // RcvNxt is the TCP variable RCV.NXT.
- RcvNxt seqnum.Value
-
- // RcvAcc is the TCP variable RCV.ACC.
- RcvAcc seqnum.Value
-
- // RcvWndScale is the window scaling to use for inbound segments.
- RcvWndScale uint8
-
- // PendingBufUsed is the number of bytes pending in the receive
- // queue.
- PendingBufUsed int
-}
-
-// TCPSenderState holds a copy of the internal state of the sender for
-// a given TCP Endpoint.
-type TCPSenderState struct {
- // LastSendTime is the time at which we sent the last segment.
- LastSendTime time.Time
-
- // DupAckCount is the number of Duplicate ACK's received.
- DupAckCount int
-
- // SndCwnd is the size of the sending congestion window in packets.
- SndCwnd int
-
- // Ssthresh is the slow start threshold in packets.
- Ssthresh int
-
- // SndCAAckCount is the number of packets consumed in congestion
- // avoidance mode.
- SndCAAckCount int
-
- // Outstanding is the number of packets in flight.
- Outstanding int
-
- // SackedOut is the number of packets which have been selectively acked.
- SackedOut int
-
- // SndWnd is the send window size in bytes.
- SndWnd seqnum.Size
-
- // SndUna is the next unacknowledged sequence number.
- SndUna seqnum.Value
-
- // SndNxt is the sequence number of the next segment to be sent.
- SndNxt seqnum.Value
-
- // RTTMeasureSeqNum is the sequence number being used for the latest RTT
- // measurement.
- RTTMeasureSeqNum seqnum.Value
-
- // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
- RTTMeasureTime time.Time
-
- // Closed indicates that the caller has closed the endpoint for sending.
- Closed bool
-
- // SRTT is the smoothed round-trip time as defined in section 2 of
- // RFC 6298.
- SRTT time.Duration
-
- // RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
- RTO time.Duration
-
- // RTTVar is the round-trip time variation as defined in section 2 of
- // RFC 6298.
- RTTVar time.Duration
-
- // SRTTInited if true indicates take a valid RTT measurement has been
- // completed.
- SRTTInited bool
-
- // MaxPayloadSize is the maximum size of the payload of a given segment.
- // It is initialized on demand.
- MaxPayloadSize int
-
- // SndWndScale is the number of bits to shift left when reading the send
- // window size from a segment.
- SndWndScale uint8
-
- // MaxSentAck is the highest acknowledgement number sent till now.
- MaxSentAck seqnum.Value
-
- // FastRecovery holds the fast recovery state for the endpoint.
- FastRecovery TCPFastRecoveryState
-
- // Cubic holds the state related to CUBIC congestion control.
- Cubic TCPCubicState
-
- // RACKState holds the state related to RACK loss detection algorithm.
- RACKState TCPRACKState
-}
-
-// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
-type TCPSACKInfo struct {
- // Blocks is the list of SACK Blocks that identify the out of order segments
- // held by a given TCP endpoint.
- Blocks []header.SACKBlock
-
- // ReceivedBlocks are the SACK blocks received by this endpoint
- // from the peer endpoint.
- ReceivedBlocks []header.SACKBlock
-
- // MaxSACKED is the highest sequence number that has been SACKED
- // by the peer.
- MaxSACKED seqnum.Value
-}
-
-// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
-type RcvBufAutoTuneParams struct {
- // MeasureTime is the time at which the current measurement
- // was started.
- MeasureTime time.Time
-
- // CopiedBytes is the number of bytes copied to user space since
- // this measure began.
- CopiedBytes int
-
- // PrevCopiedBytes is the number of bytes copied to userspace in
- // the previous RTT period.
- PrevCopiedBytes int
-
- // RcvBufSize is the auto tuned receive buffer size.
- RcvBufSize int
-
- // RTT is the smoothed RTT as measured by observing the time between
- // when a byte is first acknowledged and the receipt of data that is at
- // least one window beyond the sequence number that was acknowledged.
- RTT time.Duration
-
- // RTTVar is the "round-trip time variation" as defined in section 2
- // of RFC6298.
- RTTVar time.Duration
-
- // RTTMeasureSeqNumber is the highest acceptable sequence number at the
- // time this RTT measurement period began.
- RTTMeasureSeqNumber seqnum.Value
-
- // RTTMeasureTime is the absolute time at which the current RTT
- // measurement period began.
- RTTMeasureTime time.Time
-
- // Disabled is true if an explicit receive buffer is set for the
- // endpoint.
- Disabled bool
-}
-
-// TCPEndpointState is a copy of the internal state of a TCP endpoint.
-type TCPEndpointState struct {
- // ID is a copy of the TransportEndpointID for the endpoint.
- ID TCPEndpointID
-
- // SegTime denotes the absolute time when this segment was received.
- SegTime time.Time
-
- // RcvBufSize is the size of the receive socket buffer for the endpoint.
- RcvBufSize int
-
- // RcvBufUsed is the amount of bytes actually held in the receive socket
- // buffer for the endpoint.
- RcvBufUsed int
-
- // RcvBufAutoTuneParams is used to hold state variables to compute
- // the auto tuned receive buffer size.
- RcvAutoParams RcvBufAutoTuneParams
-
- // RcvClosed if true, indicates the endpoint has been closed for reading.
- RcvClosed bool
-
- // SendTSOk is used to indicate when the TS Option has been negotiated.
- // When sendTSOk is true every non-RST segment should carry a TS as per
- // RFC7323#section-1.1.
- SendTSOk bool
-
- // RecentTS is the timestamp that should be sent in the TSEcr field of
- // the timestamp for future segments sent by the endpoint. This field is
- // updated if required when a new segment is received by this endpoint.
- RecentTS uint32
-
- // TSOffset is a randomized offset added to the value of the TSVal field
- // in the timestamp option.
- TSOffset uint32
-
- // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
- // option in the SYN/SYN-ACK.
- SACKPermitted bool
-
- // SACK holds TCP SACK related information for this endpoint.
- SACK TCPSACKInfo
-
- // SndBufSize is the size of the socket send buffer.
- SndBufSize int
-
- // SndBufUsed is the number of bytes held in the socket send buffer.
- SndBufUsed int
-
- // SndClosed indicates that the endpoint has been closed for sends.
- SndClosed bool
-
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
- // PacketTooBigCount is used to notify the main protocol routine how
- // many times a "packet too big" control packet is received.
- PacketTooBigCount int
-
- // SndMTU is the smallest MTU seen in the control packets received.
- SndMTU int
-
- // Receiver holds variables related to the TCP receiver for the endpoint.
- Receiver TCPReceiverState
-
- // Sender holds state related to the TCP Sender for the endpoint.
- Sender TCPSenderState
-}
-
// ResumableEndpoint is an endpoint that needs to be resumed after restore.
type ResumableEndpoint interface {
// Resume resumes an endpoint after restore. This can be used to restart
@@ -455,7 +154,7 @@ type Stack struct {
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
- receiveBufferSize ReceiveBufferSizeOption
+ receiveBufferSize tcpip.ReceiveBufferSizeOption
// tcpInvalidRateLimit is the maximal rate for sending duplicate
// acknowledgements in response to incoming TCP packets that are for an existing
@@ -623,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
- clock = &tcpip.StdClock{}
+ clock = tcpip.NewStdClock()
}
if opts.UniqueID == nil {
@@ -669,7 +368,7 @@ func New(opts Options) *Stack {
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
},
- receiveBufferSize: ReceiveBufferSizeOption{
+ receiveBufferSize: tcpip.ReceiveBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
+ isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
@@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1874,7 +1573,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
- return nic.WritePacketToRemote(remote, nil, netProto, pkt)
+ return nic.WritePacketToRemote(remote, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index dfec4258a..33824afd0 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -14,6 +14,78 @@
package stack
+import "time"
+
// StackFromEnv is the global stack created in restore run.
// FIXME(b/36201077)
var StackFromEnv *Stack
+
+// saveT is invoked by stateify.
+func (t *TCPCubicState) saveT() unixTime {
+ return unixTime{t.T.Unix(), t.T.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (t *TCPCubicState) loadT(unix unixTime) {
+ t.T = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (t *TCPRACKState) saveXmitTime() unixTime {
+ return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (t *TCPRACKState) loadXmitTime(unix unixTime) {
+ t.XmitTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveLastSendTime is invoked by stateify.
+func (t *TCPSenderState) saveLastSendTime() unixTime {
+ return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (t *TCPSenderState) loadLastSendTime(unix unixTime) {
+ t.LastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) saveRTTMeasureTime() unixTime {
+ return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) {
+ t.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime {
+ return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()}
+}
+
+// loadMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.MeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime {
+ return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) {
+ r.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveSegTime is invoked by stateify.
+func (t *TCPEndpointState) saveSegTime() unixTime {
+ return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()}
+}
+
+// loadSegTime is invoked by stateify.
+func (t *TCPEndpointState) loadSegTime(unix unixTime) {
+ t.SegTime = time.Unix(unix.second, unix.nano)
+}
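Each of these hooks flattens a time.Time into an integer pair for stateify and rebuilds it on load. A self-contained sketch of the save/load shape, using seconds plus the in-second nanosecond remainder so the reconstruction is exact; the package's actual unixTime field semantics are an assumption here:

package main

import (
	"fmt"
	"time"
)

// unixTime's real definition is elsewhere in the stack package; these
// field semantics (seconds plus in-second nanoseconds) are an assumption
// chosen so the round trip below is exact.
type unixTime struct {
	second int64
	nano   int64
}

func save(t time.Time) unixTime { return unixTime{t.Unix(), int64(t.Nanosecond())} }
func load(u unixTime) time.Time { return time.Unix(u.second, u.nano) }

func main() {
	now := time.Now()
	fmt.Println(load(save(now)).Equal(now)) // true
}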
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 3066f4ffd..80e8e0089 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error {
s.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
@@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error {
s.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.ReceiveBufferSizeOption:
s.mu.RLock()
*v = s.receiveBufferSize
s.mu.RUnlock()
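With ReceiveBufferSizeOption now in the tcpip package, callers go through the same Stack.SetOption/Option pair; a usage sketch mirroring the tests below (the specific sizes are arbitrary):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{})

	// Values must satisfy Min <= Default <= Max, with Min no smaller
	// than stack.MinBufferSize.
	opt := tcpip.ReceiveBufferSizeOption{Min: 4 << 10, Default: 32 << 10, Max: 1 << 20}
	if err := s.SetOption(opt); err != nil {
		fmt.Println("SetOption:", err)
		return
	}

	var rs tcpip.ReceiveBufferSizeOption
	if err := s.Option(&rs); err != nil {
		fmt.Println("Option:", err)
		return
	}
	fmt.Printf("receive buffer min/default/max = %d/%d/%d\n", rs.Min, rs.Default, rs.Max)
}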
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 2814b94b4..d2c40cc43 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -39,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -137,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data().TrimFront(fakeNetHeaderLen)
+	// DeleteFront invalidates slices. Make a copy before deleting.
+ nb := append([]byte(nil), hdr...)
+ pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -170,7 +173,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++
@@ -189,11 +192,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -436,7 +439,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error
}
func send(r *stack.Route, payload buffer.View) tcpip.Error {
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
}))
@@ -1461,7 +1464,7 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if n := ep.Drain(); n != 0 {
t.Fatalf("got ep.Drain() = %d, want = 0", n)
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: fakeTransNumber,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -1645,10 +1648,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
- nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ nic1Gateway := testutil.MustParse4("192.168.1.1")
// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
- nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+ nic2Gateway := testutil.MustParse4("10.10.10.1")
// Create a new stack with two NICs.
s := stack.New(stack.Options{
@@ -2789,25 +2792,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
- ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
- toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
nicID = 1
lifetimeSeconds = 9999
)
+ var (
+ linkLocalAddr1 = testutil.MustParse6("fe80::1")
+ linkLocalAddr2 = testutil.MustParse6("fe80::2")
+ linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
+ uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
+ uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
+ globalAddr1 = testutil.MustParse6("a000::1")
+ globalAddr2 = testutil.MustParse6("a000::2")
+ globalAddr3 = testutil.MustParse6("a000::3")
+ ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1")
+ ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2")
+ toredoAddr1 = testutil.MustParse6("2001::1")
+ toredoAddr2 = testutil.MustParse6("2001::2")
+ ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1")
+ ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2")
+ )
+
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
@@ -3017,7 +3022,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -3354,21 +3359,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- rs stack.ReceiveBufferSizeOption
+ rs tcpip.ReceiveBufferSizeOption
err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
- {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
- {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -3377,7 +3382,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
if err := s.SetOption(tc.rs); err != tc.err {
t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if tc.err == nil {
if err := s.Option(&rs); err != nil {
t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
@@ -3448,7 +3453,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
ipv4Subnet := ipv4Addr.Subnet()
ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4Gateway := testutil.MustParse4("192.168.1.1")
ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
Address: "\xc0\xa8\x01\x3a",
PrefixLen: 31,
@@ -4352,13 +4357,15 @@ func TestWritePacketToRemote(t *testing.T) {
func TestClearNeighborCacheOnNICDisable(t *testing.T) {
const (
- nicID = 1
-
- ipv4Addr = tcpip.Address("\x01\x02\x03\x04")
- ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04")
+ nicID = 1
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
)
+ var (
+ ipv4Addr = testutil.MustParse4("1.2.3.4")
+ ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304")
+ )
+
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
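The testutil package itself is not part of this diff, so the shape of MustParse4/MustParse6 is presumed here: parse a human-readable address, panic on failure. A sketch of what such helpers plausibly look like:

package testutil

import (
	"fmt"
	"net"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// MustParse4 parses addr as an IPv4 address, panicking on failure so
// tests can use readable literals. Presumed shape only.
func MustParse4(addr string) tcpip.Address {
	ip := net.ParseIP(addr).To4()
	if ip == nil {
		panic(fmt.Sprintf("MustParse4(%q): not a valid IPv4 address", addr))
	}
	return tcpip.Address(ip)
}

// MustParse6 is the 16-byte IPv6 counterpart.
func MustParse6(addr string) tcpip.Address {
	ip := net.ParseIP(addr).To16()
	if ip == nil {
		panic(fmt.Sprintf("MustParse6(%q): not a valid IPv6 address", addr))
	}
	return tcpip.Address(ip)
}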
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
new file mode 100644
index 000000000..ddff6e2d6
--- /dev/null
+++ b/pkg/tcpip/stack/tcp.go
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPCubicState struct {
+ // WLastMax is the previous wMax value.
+ WLastMax float64
+
+ // WMax is the value of the congestion window at the time of the last
+ // congestion event.
+ WMax float64
+
+ // T is the time when the current congestion avoidance was entered.
+ T time.Time `state:".(unixTime)"`
+
+ // TimeSinceLastCongestion denotes the time since the current
+ // congestion avoidance was entered.
+ TimeSinceLastCongestion time.Duration
+
+ // C is the cubic constant as specified in RFC8312, page 11.
+ C float64
+
+ // K is the time period (in seconds) that the above function takes to
+ // increase the current window size to WMax if there are no further
+ // congestion events and is calculated using the following equation:
+ //
+ // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5)
+ K float64
+
+ // Beta is the CUBIC multiplication decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // WC(0)=WMax*beta_cubic.
+ Beta float64
+
+	// WC is the window computed by CUBIC at time TimeSinceLastCongestion. It's
+ // calculated using the formula:
+ //
+ // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1)
+ WC float64
+
+ // WEst is the window computed by CUBIC at time
+ // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT).
+ WEst float64
+}
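The comments on K and WC reference Eq. 1 and Eq. 2 (RFC 8312); a quick numeric sketch showing that WC(0) lands at beta*WMax and the window regrows to WMax after K seconds:

package main

import (
	"fmt"
	"math"
)

func main() {
	const (
		c    = 0.4   // the CUBIC constant C (RFC 8312 suggests 0.4)
		beta = 0.7   // beta_cubic
		wMax = 100.0 // window at the last congestion event, in packets
	)

	// Eq. 2: K = cubic_root(WMax*(1-beta_cubic)/C).
	k := math.Cbrt(wMax * (1 - beta) / c)

	// Eq. 1: WC(t) = C*(t-K)^3 + WMax.
	wc := func(t float64) float64 { return c*math.Pow(t-k, 3) + wMax }

	fmt.Printf("K = %.2fs\n", k)
	fmt.Printf("WC(0) = %.1f (post-backoff window ~= beta*WMax)\n", wc(0))
	fmt.Printf("WC(K) = %.1f (window regrows to WMax after K seconds)\n", wc(k))
}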
+
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPRACKState struct {
+ // XmitTime is the transmission timestamp of the most recent
+ // acknowledged segment.
+ XmitTime time.Time `state:".(unixTime)"`
+
+ // EndSequence is the ending TCP sequence number of the most recent
+ // acknowledged segment.
+ EndSequence seqnum.Value
+
+ // FACK is the highest selectively or cumulatively acknowledged
+ // sequence.
+ FACK seqnum.Value
+
+ // RTT is the round trip time of the most recently delivered packet on
+ // the connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ RTT time.Duration
+
+ // Reord is true iff reordering has been detected on this connection.
+ Reord bool
+
+ // DSACKSeen is true iff the connection has seen a DSACK.
+ DSACKSeen bool
+
+ // ReoWnd is the reordering window time used for recording packet
+ // transmission times. It is used to defer the moment at which RACK
+ // marks a packet lost.
+ ReoWnd time.Duration
+
+ // ReoWndIncr is the multiplier applied to adjust reorder window.
+ ReoWndIncr uint8
+
+ // ReoWndPersist is the number of loss recoveries before resetting
+ // reorder window.
+ ReoWndPersist int8
+
+ // RTTSeq is the SND.NXT when RTT is updated.
+ RTTSeq seqnum.Value
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+//
+// +stateify savable
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+	// RemoteAddress is the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+//
+// +stateify savable
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery. The
+ // following fields are only meaningful when Active is true.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ // It exists to avoid attacks where the receiver intentionally sends
+ // duplicate acks to artificially inflate the sender's cwnd.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase. See: RFC 6675 Section 2 for
+ // details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission. See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for a
+// given TCP endpoint.
+//
+// +stateify savable
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is one beyond the last acceptable sequence number. That is,
+ // the "largest" sequence value that the receiver has announced to its
+ // peer that it's willing to accept. This may be different than RcvNxt
+ // + (last advertised receive window) if the receive window is reduced;
+ // in that case we have to reduce the window as we receive more data
+ // instead of shrinking it.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive queue.
+ PendingBufUsed int
+}
+
+// TCPRTTState holds a copy of information about the endpoint's round trip
+// time.
+//
+// +stateify savable
+type TCPRTTState struct {
+ // SRTT is the smoothed round trip time defined in section 2 of RFC
+ // 6298.
+ SRTT time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for a given
+// TCP Endpoint.
+//
+// +stateify savable
+type TCPSenderState struct {
+ // LastSendTime is the timestamp at which we sent the last segment.
+ LastSendTime time.Time `state:".(unixTime)"`
+
+ // DupAckCount is the number of Duplicate ACKs received. It is used for
+ // fast retransmit.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the threshold between slow start and congestion
+ // avoidance.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets acknowledged during
+ // congestion avoidance. When enough packets have been ack'd (typically
+ // cwnd packets), the congestion window is incremented by one.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets that have been sent but not yet
+ // acknowledged.
+ Outstanding int
+
+ // SackedOut is the number of packets which have been selectively
+ // acked.
+ SackedOut int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest
+ // RTT measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Closed indicates that the caller has closed the endpoint for
+ // sending.
+ Closed bool
+
+ // RTO is the retransmit timeout as defined in section of 2 of RFC
+ // 6298.
+ RTO time.Duration
+
+ // RTTState holds information about the endpoint's round trip time.
+ RTTState TCPRTTState
+
+ // MaxPayloadSize is the maximum size of the payload of a given
+ // segment. It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the
+ // send window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+
+ // RACKState holds the state related to RACK loss detection algorithm.
+ RACKState TCPRACKState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+//
+// +stateify savable
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order
+ // segments held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint from
+ // the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED by the
+ // peer.
+ MaxSACKED seqnum.Value
+}
+
+// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
+//
+// +stateify savable
+type RcvBufAutoTuneParams struct {
+ // MeasureTime is the time at which the current measurement was
+ // started.
+ MeasureTime time.Time `state:".(unixTime)"`
+
+ // CopiedBytes is the number of bytes copied to user space since this
+ // measure began.
+ CopiedBytes int
+
+ // PrevCopiedBytes is the number of bytes copied to userspace in the
+ // previous RTT period.
+ PrevCopiedBytes int
+
+ // RcvBufSize is the auto tuned receive buffer size.
+ RcvBufSize int
+
+ // RTT is the smoothed RTT as measured by observing the time between
+ // when a byte is first acknowledged and the receipt of data that is at
+ // least one window beyond the sequence number that was acknowledged.
+ RTT time.Duration
+
+ // RTTVar is the "round-trip time variation" as defined in section 2 of
+ // RFC6298.
+ RTTVar time.Duration
+
+ // RTTMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ RTTMeasureSeqNumber seqnum.Value
+
+ // RTTMeasureTime is the absolute time at which the current RTT
+ // measurement period began.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ Disabled bool
+}
+
+// TCPRcvBufState contains information about the state of an endpoint's receive
+// socket buffer.
+//
+// +stateify savable
+type TCPRcvBufState struct {
+ // RcvBufUsed is the amount of bytes actually held in the receive
+ // socket buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvAutoParams holds the state variables used to compute the
+ // auto-tuned receive buffer size.
+ RcvAutoParams RcvBufAutoTuneParams
+
+ // RcvClosed, if true, indicates that the endpoint has been closed
+ // for reading.
+ RcvClosed bool
+}
+
+// TCPSndBufState contains information about the state of an endpoint's send
+// socket buffer.
+//
+// +stateify savable
+type TCPSndBufState struct {
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet has been received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+}
+
+// TCPEndpointStateInner contains the members of TCPEndpointState used directly
+// (that is, not within another containing struct) within the endpoint's
+// internal implementation.
+//
+// +stateify savable
+type TCPEndpointStateInner struct {
+ // TSOffset is a randomized offset added to the value of the TSVal
+ // field in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When SendTSOk is true, every non-RST segment should carry a TS as
+ // per RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field
+ // is updated if required when a new segment is received by this
+ // endpoint.
+ RecentTS uint32
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+//
+// +stateify savable
+type TCPEndpointState struct {
+ // TCPEndpointStateInner contains the members of TCPEndpointState used
+ // by the endpoint's internal implementation.
+ TCPEndpointStateInner
+
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time `state:".(unixTime)"`
+
+ // RcvBufState contains information about the state of the endpoint's
+ // receive socket buffer.
+ RcvBufState TCPRcvBufState
+
+ // SndBufState contains information about the state of the endpoint's
+ // send socket buffer.
+ SndBufState TCPSndBufState
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // Receiver holds variables related to the TCP receiver for the
+ // endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
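
Taken together, these structs give debugging and introspection code a self-contained snapshot of a TCP endpoint. Below is a minimal consuming sketch, assuming (as the surrounding paths suggest) that these types live in pkg/tcpip/stack; inspectState is a hypothetical helper, not part of the change:

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// inspectState derives simple quantities from a TCPEndpointState copy.
func inspectState(s *stack.TCPEndpointState) {
	// Bytes sent but not yet cumulatively acknowledged.
	unacked := seqnum.Size(s.Sender.SndNxt - s.Sender.SndUna)
	fmt.Printf("id=%+v unacked=%d sndWnd=%d rto=%s\n",
		s.ID, unacked, s.Sender.SndWnd, s.Sender.RTO)
}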
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e188efccb..80ad1a9d4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
return eps
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
+// handlePacket is called by the stack when new packets arrive to this transport
+// endpoint. It returns false if the packet could not be matched to any
+// transport endpoint, true otherwise.
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return false
}
}
@@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
- return
+ return true
}
transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return true
}
// handleError delivers an error to the transport endpoint identified by id.
@@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber,
}
return false
}
- ep.handlePacket(id, pkt)
- return true
+ return ep.handlePacket(id, pkt)
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
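
Returning a boolean from handlePacket lets deliverPacket (above) report whether any transport endpoint actually claimed the packet. A purely illustrative sketch of the caller-side shape this enables; the hunks here only change the plumbing:

// Hypothetical in-package caller: act on the match result instead of
// assuming delivery succeeded.
if !epsByNIC.handlePacket(id, pkt) {
	// No endpoint matched on this NIC: the caller can count the miss or
	// generate an error response (e.g. RST / ICMP destination unreachable).
}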
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 054cced0c..839178809 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
- ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return ep
}
@@ -106,7 +106,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
Data: buffer.View(v).ToVectorisedView(),
})
_ = pkt.TransportHeader().Push(fakeTransHeaderLen)
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
+ if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
return 0, err
}
@@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress(),
route: route,
}
- ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}
diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go
new file mode 100644
index 000000000..7ce43a68e
--- /dev/null
+++ b/pkg/tcpip/stdclock.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// stdClock implements Clock with the time package.
+//
+// +stateify savable
+type stdClock struct {
+ // baseTime holds the time when the clock was constructed.
+ //
+ // This value is used to calculate the monotonic time from the time package.
+ // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks,
+ //
+ // Operating systems provide both a “wall clock,” which is subject to
+ // changes for clock synchronization, and a “monotonic clock,” which is not.
+ // The general rule is that the wall clock is for telling time and the
+ // monotonic clock is for measuring time. Rather than split the API, in this
+ // package the Time returned by time.Now contains both a wall clock reading
+ // and a monotonic clock reading; later time-telling operations use the wall
+ // clock reading, but later time-measuring operations, specifically
+ // comparisons and subtractions, use the monotonic clock reading.
+ //
+ // ...
+ //
+ // If Times t and u both contain monotonic clock readings, the operations
+ // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using
+ // the monotonic clock readings alone, ignoring the wall clock readings. If
+ // either t or u contains no monotonic clock reading, these operations fall
+ // back to using the wall clock readings.
+ //
+ // Given the above, we can safely conclude that time.Since(baseTime) will
+ // return monotonically increasing values if we use time.Now() to set baseTime
+ // at the time of clock construction.
+ //
+ // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per
+ // https://golang.org/pkg/time/#Since.
+ baseTime time.Time `state:"nosave"`
+
+ // monotonicOffset is the offset applied to the calculated monotonic time.
+ //
+ // monotonicOffset is assigned maxMonotonic after restore so that the
+ // monotonic time will continue from where it "left off" before saving as part
+ // of S/R.
+ monotonicOffset int64 `state:"nosave"`
+
+ // monotonicMU protects maxMonotonic.
+ monotonicMU sync.Mutex `state:"nosave"`
+ maxMonotonic int64
+}
+
+// NewStdClock returns an instance of a clock that uses the time package.
+func NewStdClock() Clock {
+ return &stdClock{
+ baseTime: time.Now(),
+ }
+}
+
+var _ Clock = (*stdClock)(nil)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*stdClock) NowNanoseconds() int64 {
+ return time.Now().UnixNano()
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (s *stdClock) NowMonotonic() int64 {
+ sinceBase := time.Since(s.baseTime)
+ if sinceBase < 0 {
+ panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime))
+ }
+
+ monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset
+
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+
+ // Monotonic time values must never decrease.
+ if monotonicValue > s.maxMonotonic {
+ s.maxMonotonic = monotonicValue
+ }
+
+ return s.maxMonotonic
+}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*stdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
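
A short, self-contained sketch of the clock's contract: wall time comes from time.Now, while monotonic readings are clamped so they can never move backwards, even if the wall clock is stepped:

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	clock := tcpip.NewStdClock()
	t0 := clock.NowMonotonic() // nanoseconds since clock construction
	time.Sleep(10 * time.Millisecond)
	t1 := clock.NowMonotonic()
	fmt.Println(t1 >= t0) // always true: readings are clamped to the maximum seen
}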
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/stdclock_state.go
index d0f58cfaf..795db9181 100644
--- a/pkg/tcpip/transport/tcp/cubic_state.go
+++ b/pkg/tcpip/stdclock_state.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+package tcpip
-import (
- "time"
-)
+import "time"
-// saveT is invoked by stateify.
-func (c *cubicState) saveT() unixTime {
- return unixTime{c.t.Unix(), c.t.UnixNano()}
-}
+// afterLoad is invoked by stateify.
+func (s *stdClock) afterLoad() {
+ s.baseTime = time.Now()
-// loadT is invoked by stateify.
-func (c *cubicState) loadT(unix unixTime) {
- c.t = time.Unix(unix.second, unix.nano)
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+ s.monotonicOffset = s.maxMonotonic
}
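
A worked example of the restore path, with illustrative numbers:

// Suppose maxMonotonic was 100e9 ns (100s of uptime) when the state was saved.
// afterLoad sets baseTime = time.Now() and monotonicOffset = 100e9, so the
// first post-restore NowMonotonic() returns roughly
//	time.Since(baseTime) + 100e9 ≈ 100s,
// i.e. monotonic time resumes where it left off instead of restarting at zero.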
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 87ea09a5e..d5f941c5f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -73,7 +73,7 @@ type Clock interface {
// nanoseconds since the Unix epoch.
NowNanoseconds() int64
- // NowMonotonic returns a monotonic time value.
+ // NowMonotonic returns a monotonic time value at nanosecond resolution.
NowMonotonic() int64
// AfterFunc waits for the duration to elapse and then calls f in its own
@@ -691,10 +691,6 @@ const (
// number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption
- // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
- // specify the receive buffer size option.
- ReceiveBufferSizeOption
-
// SendQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
@@ -786,6 +782,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {}
func (*TCPRecovery) isSettableTransportProtocolOption() {}
+// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies.
+type TCPAlwaysUseSynCookies bool
+
+func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {}
+
+func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {}
+
const (
// TCPRACKLossDetection indicates RACK is used for loss detection and
// recovery.
@@ -1020,19 +1023,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}
func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}
-// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
-// the number of endpoints that can be in SYN-RCVD state before the stack
-// switches to using SYN cookies.
-type TCPSynRcvdCountThresholdOption uint64
-
-func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
-
// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
// default for number of times SYN is retransmitted before aborting a connect.
type TCPSynRetriesOption uint8
@@ -1117,6 +1107,7 @@ const (
// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
// duration for which a socket lingers before returning from Close.
//
+// +marshal
// +stateify savable
type LingerOption struct {
Enabled bool
@@ -1150,6 +1141,19 @@ type SendBufferSizeOption struct {
Max int
}
+// ReceiveBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ // Min is the minimum size for the receive buffer.
+ Min int
+
+ // Default is the default size for the receive buffer.
+ Default int
+
+ // Max is the maximum size for the receive buffer.
+ Max int
+}
+
// GetSendBufferLimits is used to get the send buffer size limits.
type GetSendBufferLimits func(StackHandler) SendBufferSizeOption
@@ -1162,6 +1166,18 @@ func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption {
return ss
}
+// GetReceiveBufferLimits is used to get the receive buffer size limits.
+type GetReceiveBufferLimits func(StackHandler) ReceiveBufferSizeOption
+
+// GetStackReceiveBufferLimits is used to get the default, min and max receive buffer sizes.
+func GetStackReceiveBufferLimits(so StackHandler) ReceiveBufferSizeOption {
+ var ss ReceiveBufferSizeOption
+ if err := so.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ return ss
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -1218,7 +1234,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value() uint64 {
+func (s *StatCounter) Value(name ...string) uint64 {
return atomic.LoadUint64(&s.count)
}
@@ -1512,6 +1528,30 @@ type IGMPStats struct {
// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats)
}
+// IPForwardingStats collects stats related to IP forwarding (both v4 and v6).
+type IPForwardingStats struct {
+ // Unrouteable is the number of IP packets received which were dropped
+ // because the netstack could not construct a route to their
+ // destination.
+ Unrouteable *StatCounter
+
+ // ExhaustedTTL is the number of IP packets received which were dropped
+ // because their TTL was exhausted.
+ ExhaustedTTL *StatCounter
+
+ // LinkLocalSource is the number of IP packets which were dropped
+ // because they contained a link-local source address.
+ LinkLocalSource *StatCounter
+
+ // LinkLocalDestination is the number of IP packets which were dropped
+ // because they contained a link-local destination address.
+ LinkLocalDestination *StatCounter
+
+ // Errors is the number of IP packets received which could not be
+ // successfully forwarded.
+ Errors *StatCounter
+}
+
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// LINT.IfChange(IPStats)
@@ -1562,6 +1602,10 @@ type IPStats struct {
// chain.
IPTablesOutputDropped *StatCounter
+ // IPTablesPostroutingDropped is the number of IP packets dropped in the
+ // Postrouting chain.
+ IPTablesPostroutingDropped *StatCounter
+
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
@@ -1576,6 +1620,9 @@ type IPStats struct {
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived *StatCounter
+ // Forwarding collects stats related to IP forwarding.
+ Forwarding IPForwardingStats
+
// LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats)
}
@@ -1734,6 +1781,10 @@ type TCPStats struct {
// ChecksumErrors is the number of segments dropped due to bad checksums.
ChecksumErrors *StatCounter
+
+ // FailedPortReservations is the number of times TCP failed to reserve
+ // a port.
+ FailedPortReservations *StatCounter
}
// UDPStats collects UDP-specific stats.
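
The new forwarding counters hang off the existing per-stack stats tree; the forwarding code increments them at the corresponding drop points. A minimal read-side sketch, assuming a running *stack.Stack named s:

fwd := s.Stats().IP.Forwarding
fmt.Printf("unrouteable=%d ttl-exhausted=%d errors=%d\n",
	fwd.Unrouteable.Value(), fwd.ExhaustedTTL.Value(), fwd.Errors.Value())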
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 3cc8c36f1..d4f7bb5ff 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -9,11 +9,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -78,6 +81,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
@@ -101,6 +105,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -123,6 +128,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index d10ae05c2..dbd279c94 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -21,11 +21,14 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -312,3 +315,194 @@ func TestForwarding(t *testing.T) {
})
}
}
+
+func TestMulticastForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ ttl = 64
+ )
+
+ var (
+ ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
+ ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+
+ ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
+ ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+ )
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+ }
+
+ v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+ }
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ expectForward bool
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 link-local multicast destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local source",
+ srcAddr: ipv4LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv4Addr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalUnicastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 non-link-local unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 non-link-local multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 link-local multicast destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local source",
+ srcAddr: ipv6LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv6Addr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalUnicastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 non-link-local unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 non-link-local multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ p, ok := e2.Read()
+ if ok != test.expectForward {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
+ }
+
+ if test.expectForward {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 1cfd854a0..c61d4e788 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -347,7 +347,7 @@ type channelEndpointWithoutWritePacket struct {
t *testing.T
}
-func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
return &tcpip.ErrNotSupported{}
}
@@ -627,7 +627,7 @@ func TestIPTableWritePackets(t *testing.T) {
pkts := test.genPacket(r)
pktsLen := pkts.Len()
- if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: 64,
}); err != nil {
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index d39809e1c..c657714ba 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -687,10 +687,10 @@ func TestWritePacketsLinkResolution(t *testing.T) {
TOS: stack.DefaultTOS,
}
- if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
- t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
+ if n, err := r.WritePackets(pkts, params); err != nil {
+ t.Fatalf("r.WritePackets(_, %#v): %s", params, err)
} else if want := pkts.Len(); want != n {
- t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
+ t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want)
}
var writer bytes.Buffer
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 2c538a43e..3df1bbd68 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -314,11 +315,11 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
TOS: stack.DefaultTOS,
}
data := buffer.View([]byte{1, 2, 3, 4})
- if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: data.ToVectorisedView(),
})); err != nil {
- t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err)
+ t.Fatalf("r.WritePacket(%#v, _): %s", params, err)
}
// Removing the address should make the endpoint invalid.
@@ -326,12 +327,12 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
}
{
- err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: data.ToVectorisedView(),
}))
if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
- t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
+ t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
}
}
}
@@ -510,25 +511,25 @@ func TestExternalLoopbackTraffic(t *testing.T) {
nicID1 = 1
nicID2 = 2
- ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
-
numPackets = 1
+ ttl = 64
)
+ ipv4Loopback := testutil.MustParse4("127.0.0.1")
loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl)
}
loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl)
}
loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl)
}
loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl)
}
invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index c6a9c2393..2d0a6e6a7 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -43,12 +44,15 @@ const (
// to a multicast or broadcast address uses a unicast source address for the
// reply.
func TestPingMulticastBroadcast(t *testing.T) {
- const nicID = 1
+ const (
+ nicID = 1
+ ttl = 64
+ )
tests := []struct {
name string
protoNum tcpip.NetworkProtocolNumber
- rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
srcAddr tcpip.Address
dstAddr tcpip.Address
expectedSrc tcpip.Address
@@ -136,7 +140,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- test.rxICMP(e, test.srcAddr, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
@@ -435,10 +439,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
// interested endpoints.
func TestReuseAddrAndBroadcast(t *testing.T) {
const (
- nicID = 1
- localPort = 9000
- loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ nicID = 1
+ localPort = 9000
)
+ loopbackBroadcast := testutil.MustParse4("127.255.255.255")
tests := []struct {
name string
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 78244f4eb..ac3c703d4 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -40,13 +41,13 @@ import (
// This tests that a local route is created and packets do not leave the stack.
func TestLocalPing(t *testing.T) {
const (
- nicID = 1
- ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+ nicID = 1
// icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
// request/reply packets.
icmpDataOffset = 8
)
+ ipv4Loopback := testutil.MustParse4("127.0.0.1")
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index d1c9f3a94..8fd9be32b 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -48,10 +48,6 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
-const (
- ttl = 255
-)
-
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
new file mode 100644
index 000000000..472545a5d
--- /dev/null
+++ b/pkg/tcpip/testutil/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = True,
+ srcs = ["testutil.go"],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
+
+go_test(
+ name = "testutil_test",
+ srcs = ["testutil_test.go"],
+ library = ":testutil",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
new file mode 100644
index 000000000..1aaed590f
--- /dev/null
+++ b/pkg/tcpip/testutil/testutil.go
@@ -0,0 +1,43 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil provides helper functions for netstack unit tests.
+package testutil
+
+import (
+ "fmt"
+ "net"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address.
+// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes.
+func MustParse4(addr string) tcpip.Address {
+ ip := net.ParseIP(addr).To4()
+ if ip == nil {
+ panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr))
+ }
+ return tcpip.Address(ip)
+}
+
+// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing
+// an IPv4 address will yield an IPv4-mapped IPv6 address.
+func MustParse6(addr string) tcpip.Address {
+ ip := net.ParseIP(addr).To16()
+ if ip == nil {
+ panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr))
+ }
+ return tcpip.Address(ip)
+}
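
Typical test usage, replacing hand-encoded byte strings such as tcpip.Address("\x7f\x00\x00\x01") at the call sites updated above:

ipv4Loopback := testutil.MustParse4("127.0.0.1") // 4-byte tcpip.Address
linkLocal := testutil.MustParse6("fe80::1")      // 16-byte tcpip.Address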
diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go
new file mode 100644
index 000000000..6aad9585d
--- /dev/null
+++ b/pkg/tcpip/testutil/testutil_test.go
@@ -0,0 +1,103 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// Who tests the testutils?
+
+func TestMustParse4(t *testing.T) {
+ tcs := []struct {
+ str string
+ addr tcpip.Address
+ shouldPanic bool
+ }{
+ {
+ str: "127.0.0.1",
+ addr: "\x7f\x00\x00\x01",
+ }, {
+ str: "",
+ shouldPanic: true,
+ }, {
+ str: "fe80::1",
+ shouldPanic: true,
+ }, {
+ // In an ideal world this panics too, but net.IP
+ // doesn't distinguish between IPv4 and IPv4-mapped
+ // addresses.
+ str: "::ffff:0.0.0.1",
+ addr: "\x00\x00\x00\x01",
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.str, func(t *testing.T) {
+ if tc.shouldPanic {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("panic expected, but did not occur")
+ }
+ }()
+ }
+ if got := MustParse4(tc.str); got != tc.addr {
+ t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr)
+ }
+ })
+ }
+}
+
+func TestMustParse6(t *testing.T) {
+ tcs := []struct {
+ str string
+ addr tcpip.Address
+ shouldPanic bool
+ }{
+ {
+ // In an ideal world this panics too, but net.IP
+ // doesn't distinguish between IPv4 and IPv4-mapped
+ // addresses.
+ str: "127.0.0.1",
+ addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01",
+ }, {
+ str: "",
+ shouldPanic: true,
+ }, {
+ str: "fe80::1",
+ addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }, {
+ str: "::ffff:0.0.0.1",
+ addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01",
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.str, func(t *testing.T) {
+ if tc.shouldPanic {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("panic expected, but did not occur")
+ }
+ }()
+ }
+ if got := MustParse6(tc.str); got != tc.addr {
+ t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
deleted file mode 100644
index eeea97b12..000000000
--- a/pkg/tcpip/time_unsafe.go
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build go1.9
-// +build !go1.18
-
-// Check go:linkname function signatures when updating Go version.
-
-package tcpip
-
-import (
- "time" // Used with go:linkname.
- _ "unsafe" // Required for go:linkname.
-)
-
-// StdClock implements Clock with the time package.
-//
-// +stateify savable
-type StdClock struct{}
-
-var _ Clock = (*StdClock)(nil)
-
-//go:linkname now time.now
-func now() (sec int64, nsec int32, mono int64)
-
-// NowNanoseconds implements Clock.NowNanoseconds.
-func (*StdClock) NowNanoseconds() int64 {
- sec, nsec, _ := now()
- return sec*1e9 + int64(nsec)
-}
-
-// NowMonotonic implements Clock.NowMonotonic.
-func (*StdClock) NowMonotonic() int64 {
- _, _, mono := now()
- return mono
-}
-
-// AfterFunc implements Clock.AfterFunc.
-func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
- return &stdTimer{
- t: time.AfterFunc(d, f),
- }
-}
-
-type stdTimer struct {
- t *time.Timer
-}
-
-var _ Timer = (*stdTimer)(nil)
-
-// Stop implements Timer.Stop.
-func (st *stdTimer) Stop() bool {
- return st.t.Stop()
-}
-
-// Reset implements Timer.Reset.
-func (st *stdTimer) Reset(d time.Duration) {
- st.t.Reset(d)
-}
-
-// NewStdTimer returns a Timer implemented with the time package.
-func NewStdTimer(t *time.Timer) Timer {
- return &stdTimer{t: t}
-}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index a82384c49..1633d0aeb 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -29,7 +29,7 @@ const (
)
func TestJobReschedule(t *testing.T) {
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
wg.Done()
})
job.Schedule(shortDuration)
@@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) {
func TestJobExecution(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
ch <- struct{}{}
})
job.Schedule(shortDuration)
@@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(middleDuration)
lock.Lock()
@@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) {
func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
@@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
job.Cancel()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 50991c3c0..8afde7fca 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -63,12 +63,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList icmpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -84,6 +83,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates that new packets should not be delivered to the
+ // endpoint, e.g. while its state is being saved as part of S/R.
+ frozen bool
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
@@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- state: stateInitial,
- uniqueID: s.UniqueID(),
+ waiterQueue: waiterQueue,
+ state: stateInitial,
+ uniqueID: s.UniqueID(),
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
+ var rs tcpip.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
+ }
return ep, nil
}
@@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.rcvMu.Lock()
v := int(e.ttl)
@@ -430,7 +431,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
return err
}
@@ -477,7 +478,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
}
@@ -746,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -755,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes an endpoint previously frozen with endpoint.freeze(),
+// allowing new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
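
freeze/thaw replaces the earlier save-time trick of zeroing rcvBufSizeMax under rcvMu. A sketch of how the stateify hooks use the pair (the real hooks appear in endpoint_state.go below):

// beforeSave is invoked by stateify: suspend packet delivery while saving.
func (e *endpoint) beforeSave() { e.freeze() }

// Resume re-enables delivery on restore; the full Resume also reinitializes
// the SocketOptions handler and re-registers the endpoint with the stack.
func (e *endpoint) Resume(s *stack.Stack) { e.thaw() }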
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index a3c6db5a8..28a56a2d5 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if e.state != stateBound && e.state != stateConnected {
return
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 52ed9560c..496eca581 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -72,11 +72,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -91,6 +90,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates that new packets should not be delivered to the
+ // endpoint, e.g. while its state is being saved as part of S/R.
+ frozen bool
}
// NewEndpoint returns a new packet endpoint.
@@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
},
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- ep.rcvBufSizeMax = rs.Default
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
@@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := ep.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- ep.rcvMu.Lock()
- ep.rcvBufSizeMax = v
- ep.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
func (ep *endpoint) LastError() tcpip.Error {
@@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
ep.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- v := ep.rcvBufSizeMax
- ep.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
return
}
- if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ rcvBufSize := ep.ops.GetReceiveBufferSize()
+ if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (ep *endpoint) freeze() {
+ ep.mu.Lock()
+ ep.frozen = true
+ ep.mu.Unlock()
+}
+
+// thaw unfreezes an endpoint previously frozen with endpoint.freeze(),
+// allowing new packets to be delivered again.
+func (ep *endpoint) thaw() {
+ ep.mu.Lock()
+ ep.frozen = false
+ ep.mu.Unlock()
+}
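
With the per-endpoint rcvBufSizeMax field gone, receive-buffer sizing is funneled through tcpip.SocketOptions for every endpoint type. A sketch of the shared pattern, with illustrative values:

// InitHandler wires in both stack-wide limit getters; requested sizes are
// validated against the stack's ReceiveBufferSizeOption limits.
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32<<10, false /* notify */) // replaces rcvBufSizeMax = 32 * 1024
// HandlePacket then compares the accumulated rcvBufSize against
// ep.ops.GetReceiveBufferSize() to decide whether to drop.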
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index ece662c0d..5bd860d20 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- ep.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- ep.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+ ep.freeze()
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
+ ep.thaw()
ep.stack = stack.StackFromEnv
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
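Registering both send and receive buffer limit getters with InitHandler delegates size clamping to the tcpip.SocketOptions layer. A rough sketch of the clamp-to-stack-limits rule this relies on follows; the bufferLimits type and field names are assumptions for illustration, not the exact gvisor API:

package main

import "fmt"

// bufferLimits mirrors the shape of a stack-level buffer size option:
// a minimum, a default, and a maximum.
type bufferLimits struct {
	Min, Default, Max int64
}

// clampBufferSize applies the usual rule: a requested size is forced
// into the [Min, Max] range.
func clampBufferSize(req int64, l bufferLimits) int64 {
	if req < l.Min {
		return l.Min
	}
	if req > l.Max {
		return l.Max
	}
	return req
}

func main() {
	l := bufferLimits{Min: 4 << 10, Default: 32 << 10, Max: 4 << 20}
	fmt.Println(clampBufferSize(1, l))         // 4096: clamped up to Min
	fmt.Println(clampBufferSize(l.Default, l)) // 32768: within range
	fmt.Println(clampBufferSize(1<<30, l))     // 4194304: clamped to Max
}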
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e27a249cd..bcec3d2e7 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,7 +26,6 @@
package raw
import (
- "fmt"
"io"
"gvisor.dev/gvisor/pkg/sync"
@@ -69,11 +68,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList rawPacketList
- rcvBufSize int
- rcvBufSizeMax int `state:".(int)"`
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList rawPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -89,6 +87,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+	// frozen indicates whether delivery of new packets to the endpoint is
+	// suspended, e.g. while endpoint state is being saved for restore.
+ frozen bool
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- associated: associated,
+ waiterQueue: waiterQueue,
+ associated: associated,
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
// Unassociated endpoints are write-only and users call Write() with IP
// headers included. Because they're write-only, we don't need to
// register with the stack.
if !associated {
- e.rcvBufSizeMax = 0
+		e.ops.SetReceiveBufferSize(0, false /* notify */)
e.waiterQueue = nil
return e, nil
}
@@ -352,7 +354,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
pkt.Owner = owner
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := route.WritePacket(stack.NetworkHeaderParams{
Protocol: e.TransProto,
TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
@@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- e.rcvMu.Lock()
- e.rcvBufSizeMax = v
- e.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
@@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes an endpoint previously frozen by endpoint.freeze(),
+// allowing new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
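Both HandlePacket paths now consult ep.ops.GetReceiveBufferSize() and, on overflow or while frozen, bump a stack-wide drop counter plus a per-endpoint overflow counter. A toy version of that drop accounting, with atomic counters standing in for tcpip.StatCounter:

package main

import (
	"fmt"
	"sync/atomic"
)

// counter is a stand-in for tcpip.StatCounter.
type counter struct{ v uint64 }

func (c *counter) Increment()    { atomic.AddUint64(&c.v, 1) }
func (c *counter) Value() uint64 { return atomic.LoadUint64(&c.v) }

type rxStats struct {
	droppedPackets        counter // stack-wide drop counter
	receiveBufferOverflow counter // per-endpoint overflow counter
}

// accountDrop records a dropped packet in both places, as the raw and
// packet endpoints do when the receive buffer is full or frozen.
func accountDrop(s *rxStats) {
	s.droppedPackets.Increment()
	s.receiveBufferOverflow.Increment()
}

func main() {
	var s rxStats
	accountDrop(&s)
	fmt.Println(s.droppedPackets.Value(), s.receiveBufferOverflow.Value()) // 1 1
}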
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 263ec5146..5d6f2709c 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// If the endpoint is connected, re-connect.
if e.connected {
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index a69d6624d..48417f192 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -34,14 +34,12 @@ go_library(
"connect.go",
"connect_unsafe.go",
"cubic.go",
- "cubic_state.go",
"dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
"protocol.go",
"rack.go",
- "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -107,6 +105,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 025b134e2..d4bd4e80e 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -23,7 +23,6 @@ import (
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -51,11 +50,6 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-
- // SynRcvdCountThreshold is the default global maximum number of
- // connections that are allowed to be in SYN-RCVD state before TCP
- // starts using SYN cookies to accept connections.
- SynRcvdCountThreshold uint64 = 1000
)
var (
@@ -80,9 +74,6 @@ func encodeMSS(mss uint16) uint32 {
type listenContext struct {
stack *stack.Stack
- // synRcvdCount is a reference to the stack level synRcvdCount.
- synRcvdCount *synRcvdCounter
-
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
rcvWnd seqnum.Size
@@ -138,14 +129,12 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
- p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
- if !ok {
- panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
- }
- l.synRcvdCount = p.SynRcvdCounter()
- rand.Read(l.nonce[0][:])
- rand.Read(l.nonce[1][:])
+ for i := range l.nonce {
+ if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil {
+ panic(err)
+ }
+ }
return l
}
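Switching from pkg/rand to the stack's secure RNG also makes the nonce fill explicit about short reads: io.ReadFull either fills the slice completely or returns a non-nil error. An equivalent standalone snippet, using crypto/rand in place of stk.SecureRNG():

package main

import (
	"crypto/rand"
	"fmt"
	"io"
)

func main() {
	// Two 32-byte nonces, filled the same way the listen context fills
	// l.nonce: io.ReadFull guarantees a complete read or an error.
	var nonce [2][32]byte
	for i := range nonce {
		if _, err := io.ReadFull(rand.Reader, nonce[i][:]); err != nil {
			panic(err) // no entropy source; nothing sensible to do
		}
	}
	fmt.Printf("nonce[0][:4] = %x\n", nonce[0][:4])
}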
@@ -163,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc
// Feed everything to the hasher.
l.hasherMu.Lock()
l.hasher.Reset()
+
+	// Per the hash.Hash documentation (Write, via the embedded
+	// io.Writer): it never returns an error.
l.hasher.Write(payload[:])
l.hasher.Write(l.nonce[nonceIndex][:])
- io.WriteString(l.hasher, string(id.LocalAddress))
- io.WriteString(l.hasher, string(id.RemoteAddress))
+ l.hasher.Write([]byte(id.LocalAddress))
+ l.hasher.Write([]byte(id.RemoteAddress))
// Finalize the calculation of the hash and return the first 4 bytes.
- h := make([]byte, 0, sha1.Size)
- h = l.hasher.Sum(h)
+ h := l.hasher.Sum(nil)
l.hasherMu.Unlock()
return binary.BigEndian.Uint32(h[:])
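The cookieHash cleanup leans on hash.Hash's documented no-error Write and lets Sum(nil) allocate the digest. A compact standalone equivalent; the inputs below are placeholders, while the real payload packs the connection 4-tuple, timestamp, and nonce index:

package main

import (
	"crypto/sha1"
	"encoding/binary"
	"fmt"
)

// cookieHash32 hashes a payload plus a secret nonce and the two address
// strings, then returns the first four bytes of the digest, the way the
// SYN cookie hash above does.
func cookieHash32(payload, nonce, local, remote []byte) uint32 {
	h := sha1.New()
	h.Write(payload) // hash.Hash.Write never returns an error
	h.Write(nonce)
	h.Write(local)
	h.Write(remote)
	return binary.BigEndian.Uint32(h.Sum(nil))
}

func main() {
	v := cookieHash32([]byte("payload"), []byte("nonce"), []byte("1.2.3.4"), []byte("5.6.7.8"))
	fmt.Printf("%08x\n", v)
}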
@@ -199,9 +191,17 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
+func (l *listenContext) useSynCookies() bool {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
+}
+
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n := newEndpoint(l.stack, netProto, queue)
n.ops.SetV6Only(l.v6Only)
- n.ID = s.id
+ n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
n.route = route
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
- n.rcvBufSize = int(l.rcvWnd)
+ n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
- n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+ n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow()
return n, nil
}
@@ -248,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
}
@@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
}
// Register new endpoint so that packets are routed to it.
- if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ if err := ep.stack.RegisterTransportEndpoint(
+ ep.effectiveNetProtos,
+ ProtocolNumber,
+ ep.TransportEndpointInfo.ID,
+ ep,
+ ep.boundPortFlags,
+ ep.boundBindToDevice,
+ ); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -307,6 +314,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Initialize and start the handshake.
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ h.listenEP = l.listenEP
h.start()
return h, nil
}
@@ -334,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- l.pendingEndpoints[n.ID] = n
+ l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.ID)
+ delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
l.pending.Done()
l.pendingMu.Unlock()
}
@@ -382,39 +390,46 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
// scaling.
- e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+ e.rcv.RcvWndScale = e.h.effectiveRcvWndScale()
// Clean up handshake state stored in the endpoint so that it can be GCed.
e.h = nil
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state (acceptedChan is nil),
-// the new endpoint is closed instead.
+// listener has transitioned out of the listen state (accepted is the zero
+// value), the new endpoint is reset instead.
func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
defer e.pendingAccepted.Done()
- e.acceptMu.Lock()
- for {
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
- n.notifyProtocolGoroutine(notifyReset)
- return
- }
- select {
- case e.acceptedChan <- n:
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ if e.accepted == (accepted{}) {
+ return false
+ }
+ if e.accepted.endpoints.Len() == e.accepted.cap {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.accepted.endpoints.PushBack(n)
if !withSynCookie {
atomic.AddInt32(&e.synRcvdCount, -1)
}
- e.acceptMu.Unlock()
- e.waiterQueue.Notify(waiter.ReadableEvents)
- return
- default:
- e.acceptCond.Wait()
+ return true
}
+ }()
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ n.notifyProtocolGoroutine(notifyReset)
}
}
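The rewritten deliverAccepted pushes onto a bounded queue under acceptMu, waits on acceptCond while the queue is full, and defers the waiter notification until after the lock is dropped. A minimal sketch of that wait-loop shape, with toy types rather than the endpoint's:

package main

import (
	"fmt"
	"sync"
)

type acceptQueue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items []int
	cap   int
	open  bool
}

func newAcceptQueue(capacity int) *acceptQueue {
	q := &acceptQueue{cap: capacity, open: true}
	q.cond = sync.NewCond(&q.mu)
	return q
}

// deliver blocks while the queue is full and reports false if the queue
// was closed, mirroring the delivered/notifyReset split above.
func (q *acceptQueue) deliver(v int) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	for {
		if !q.open {
			return false
		}
		if len(q.items) < q.cap {
			q.items = append(q.items, v)
			return true
		}
		q.cond.Wait()
	}
}

func main() {
	q := newAcceptQueue(1)
	fmt.Println(q.deliver(42)) // true: queued for Accept
	q.mu.Lock()
	q.open = false
	q.mu.Unlock()
	q.cond.Broadcast()
	fmt.Println(q.deliver(43)) // false: listener gone, peer would be reset
}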
@@ -436,17 +451,21 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// * propagateInheritableOptionsLocked has been called.
// * e.mu is held.
func (e *endpoint) reserveTupleLocked() bool {
- dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ dest := tcpip.FullAddress{
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
+ }
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: dest,
}
if !e.stack.ReserveTuple(portRes) {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return false
}
@@ -485,7 +504,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- defer ctx.synRcvdCount.dec()
if err := h.complete(); err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -497,24 +515,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }() // S/R-SAFE: synRcvdCount is the barrier.
+ }()
return nil
}
-func (e *endpoint) incSynRcvdCount() bool {
+func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
+ acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
- if canInc {
- atomic.AddInt32(&e.synRcvdCount, 1)
- }
- return canInc
+	// The capacity of the accepted queue is always one greater than the
+	// listen backlog, but the count of SYN-RCVD connections is checked
+	// against the listen backlog itself for parity with Linux:
+	// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+	//
+	// An equality check suffices here because synRcvdCount is incremented
+	// and compared only from a single listener context, and the capacity
+	// of the accepted queue can only grow via a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
+ full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
e.acceptMu.Unlock()
return full
}
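synRcvdBacklogFull compares against acceptedCap-1 because the accepted queue is provisioned with one slot more than the listen backlog, while Linux counts SYN-RCVD connections against the backlog itself. The arithmetic in isolation:

package main

import "fmt"

// backlogFull mirrors synRcvdBacklogFull: the queue capacity is
// backlog+1, so SYN-RCVD connections are measured against capacity-1,
// i.e. the user-visible backlog, matching Linux.
func backlogFull(synRcvdCount, acceptedCap int) bool {
	return synRcvdCount == acceptedCap-1
}

func main() {
	const backlog = 8
	acceptedCap := backlog + 1
	fmt.Println(backlogFull(7, acceptedCap)) // false: still room
	fmt.Println(backlogFull(8, acceptedCap)) // true: at the backlog
}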
@@ -524,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool {
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
- e.rcvListMu.Lock()
- rcvClosed := e.rcvClosed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := e.rcvQueueInfo.RcvClosed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
// If the endpoint is shutdown, reply with reset.
//
@@ -538,69 +561,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
switch {
case s.flags == header.TCPFlagSyn:
- opts := parseSynSegmentOptions(s)
- if ctx.synRcvdCount.inc() {
- // Only handle the syn if the following conditions hold
- // - accept queue is not full.
- // - number of connections in synRcvd state is less than the
- // backlog.
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
- s.incRef()
- _ = e.handleSynSegment(ctx, s, &opts)
- return nil
- }
- ctx.synRcvdCount.dec()
+ if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return nil
- } else {
- // If cookies are in use but the endpoint accept queue
- // is full then drop the syn.
- if e.acceptQueueIsFull() {
- e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
- cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ }
- route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer route.Release()
+ opts := parseSynSegmentOptions(s)
+ if !ctx.useSynCookies() {
+ s.incRef()
+ atomic.AddInt32(&e.synRcvdCount, 1)
+ return e.handleSynSegment(ctx, s, &opts)
+ }
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
- // Send SYN without window scaling because we currently
- // don't encode this information in the cookie.
- //
- // Enable Timestamp option if the original syn did have
- // the timestamp option specified.
- //
- // Use the user supplied MSS on the listening socket for
- // new connections, if available.
- synOpts := header.TCPSynOptions{
- WS: -1,
- TS: opts.TS,
- TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
- TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, route),
- }
- fields := tcpFields{
- id: s.id,
- ttl: e.ttl,
- tos: e.sendTOS,
- flags: header.TCPFlagSyn | header.TCPFlagAck,
- seq: cookie,
- ack: s.sequenceNumber + 1,
- rcvWnd: ctx.rcvWnd,
- }
- if err := e.sendSynTCP(route, fields, synOpts); err != nil {
- return err
- }
- e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
- return nil
+ // Send SYN without window scaling because we currently
+ // don't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
+ TSEcr: opts.TSVal,
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
+ }
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ fields := tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
+ return err
}
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
@@ -615,25 +624,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- if !ctx.synRcvdCount.synCookiesInUse() {
- // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
- // Any acknowledgment is bad if it arrives on a connection still in
- // the LISTEN state. An acceptable reset segment should be formed
- // for any arriving ACK-bearing segment. The RST should be
- // formatted as follows:
- //
- // <SEQ=SEG.ACK><CTL=RST>
- //
- // Send a reset as this is an ACK for which there is no
- // half open connections and we are not using cookies
- // yet.
- //
- // The only time we should reach here when a connection
- // was opened and closed really quickly and a delayed
- // ACK was received from the sender.
- return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -651,7 +641,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return nil
+
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
+	// Send a reset as this is an ACK for which there are no
+	// half-open connections and we are not using cookies
+	// yet.
+	//
+	// The only time we should reach here is when a connection
+	// was opened and closed really quickly and a delayed
+	// ACK was received from the sender.
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -672,7 +678,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
return err
}
@@ -693,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(
+ n.effectiveNetProtos,
+ ProtocolNumber,
+ n.TransportEndpointInfo.ID,
+ n,
+ n.boundPortFlags,
+ n.boundBindToDevice,
+ ); err != nil {
n.mu.Unlock()
n.Close()
@@ -708,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
// sent above.
- n.tsOffset = 0
+ n.TSOffset = 0
// Switch state to connected.
n.isConnectNotified = true
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a9e978cf6..5e03e7715 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -65,11 +65,12 @@ const (
// NOTE: handshake.ep.mu is held during handshake processing. It is released if
// we are going to block and reacquired when we start processing an event.
type handshake struct {
- ep *endpoint
- state handshakeState
- active bool
- flags header.TCPFlags
- ackNum seqnum.Value
+ ep *endpoint
+ listenEP *endpoint
+ state handshakeState
+ active bool
+ flags header.TCPFlags
+ ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
@@ -155,7 +156,7 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed())
}
// generateSecureISN generates a secure Initial Sequence number based on the
@@ -301,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
ttl = h.ep.route.DefaultTTL()
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -357,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
- TS: h.ep.sendTSOk,
+ TS: h.ep.SendTSOk,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
- SACKPermitted: h.ep.sackPermitted,
+ SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -389,13 +390,22 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
- if h.ep.sendTSOk && !s.parsedOptions.TS {
+ if h.ep.SendTSOk && !s.parsedOptions.TS {
h.ep.stack.Stats().DroppedPackets.Increment()
return nil
}
+ // Drop the ACK if the accept queue is full.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523
+ // We could abort the connection as well with a tunable as in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788
+ if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() {
+ listenEP.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// Update timestamp if required. See RFC7323, section-4.3.
- if h.ep.sendTSOk && s.parsedOptions.TS {
+ if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
@@ -485,8 +495,8 @@ func (h *handshake) start() {
// start() is also called in a listen context so we want to make sure we only
// send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
- synOpts.TS = h.ep.sendTSOk
- synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ synOpts.TS = h.ep.SendTSOk
+ synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled)
if h.sndWndScale < 0 {
// Disable window scaling if the peer did not send us
// the window scaling option.
@@ -496,7 +506,7 @@ func (h *handshake) start() {
h.sendSYNOpts = synOpts
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -544,7 +554,7 @@ func (h *handshake) complete() tcpip.Error {
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -576,8 +586,14 @@ func (h *handshake) complete() tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ // Check for any ICMP errors notified to us.
if n&notifyError != 0 {
- return h.ep.lastErrorLocked()
+ if err := h.ep.lastErrorLocked(); err != nil {
+ return err
+ }
+ // Flag the handshake failure as aborted if the lastError is
+ // cleared because of a socket layer call.
+ return &tcpip.ErrConnectionAborted{}
}
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -711,14 +727,14 @@ type tcpFields struct {
func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error {
tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
- if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, stack.GSO{}); err != nil {
e.stats.SendErrors.SynSendToNetworkFailed.Increment()
}
putOptions(tf.opts)
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO) tcpip.Error {
tf.txHash = e.txHash
if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
@@ -728,7 +744,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
return nil
}
-func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GSO) {
optLen := len(tf.opts)
tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
pkt.TransportProtocolNumber = header.TCPProtocolNumber
@@ -745,7 +761,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size()))
// Only calculate the checksum if offloading isn't supported.
- if gso != nil && gso.NeedsCsum {
+ if gso.Type != stack.GSONone && gso.NeedsCsum {
// This is called CHECKSUM_PARTIAL in the Linux kernel. We
// calculate a checksum of the pseudo-header and save it in the
// TCP header, then the kernel calculate a checksum of the
@@ -757,7 +773,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
}
}
-func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
// We need to shallow clone the VectorisedView here as ReadToView will
// split the VectorisedView and Trim underlying views as it splits. Not
// doing the clone here will cause the underlying views of data itself
@@ -789,13 +805,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkt.Data().ReadFromVV(&data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
+ pkt.GSOOptions = gso
pkts.PushBack(pkt)
}
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
}
- sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
+ sent, err := r.WritePackets(pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
if err != nil {
r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
}
@@ -805,13 +822,13 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
optLen := len(tf.opts)
if tf.rcvWnd > math.MaxUint16 {
tf.rcvWnd = math.MaxUint16
}
- if r.Loop()&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ if r.Loop()&stack.PacketLoop == 0 && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
return sendTCPBatch(r, tf, data, gso, owner)
}
@@ -819,6 +836,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen,
Data: data,
})
+ pkt.GSOOptions = gso
pkt.Hash = tf.txHash
pkt.Owner = owner
buildTCPHdr(r, tf, pkt, gso)
@@ -826,7 +844,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
+ if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
@@ -845,7 +863,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// N.B. the ordering here matches the ordering used by Linux internally
// and described in the raw makeOptions function. We don't include
// unnecessary cases here (post connection.)
- if e.sendTSOk {
+ if e.SendTSOk {
// Embed the timestamp if timestamp has been enabled.
//
// We only use the lower 32 bits of the unix time in
@@ -862,7 +880,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
- if e.sackPermitted && len(sackBlocks) > 0 {
+ if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
@@ -884,7 +902,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
}
options := e.makeOptions(sackBlocks)
err := e.sendTCP(e.route, tcpFields{
- id: e.ID,
+ id: e.TransportEndpointInfo.ID,
ttl: e.ttl,
tos: e.sendTOS,
flags: flags,
@@ -898,9 +916,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
}
func (e *endpoint) handleWrite() {
- e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Lock()
next := e.drainSendQueueLocked()
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.sendData(next)
}
@@ -909,10 +927,10 @@ func (e *endpoint) handleWrite() {
//
// Precondition: e.sndQueueInfo.sndQueueMu must be locked.
func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueue.Front()
+ first := e.sndQueueInfo.sndQueue.Front()
if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueue)
- e.sndBufInQueue = 0
+ e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
+ e.sndQueueInfo.SndBufInQueue = 0
}
return first
}
@@ -936,7 +954,7 @@ func (e *endpoint) handleClose() {
e.handleWrite()
// Mark send side as closed.
- e.snd.closed = true
+ e.snd.Closed = true
}
// resetConnectionLocked puts the endpoint in an error state with the given
@@ -958,12 +976,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) {
//
// See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
// information.
- sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd)
resetSeqNum := sndWndEnd
- if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
- resetSeqNum = e.snd.sndNxt
+ if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) {
+ resetSeqNum = e.snd.SndNxt
}
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0)
}
}
@@ -989,13 +1007,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- e.rcvListMu.Lock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.setEndpointState(StateEstablished)
}
@@ -1026,10 +1044,15 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
+ ep = e.stack.FindTransportEndpoint(
+ header.IPv4ProtocolNumber,
+ e.TransProto,
+ e.TransportEndpointInfo.ID,
+ s.nicID,
+ )
}
if ep == nil {
replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
@@ -1108,7 +1131,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
}
// handleSegments processes all inbound segments.
-func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
if e.EndpointState().closed() {
@@ -1120,7 +1145,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
break
}
- cont, err := e.handleSegment(s)
+ cont, err := e.handleSegmentLocked(s)
s.decRef()
if err != nil {
return err
@@ -1138,7 +1163,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
}
// Send an ACK for all processed packets if needed.
- if e.rcv.rcvNxt != e.snd.maxSentAck {
+ if e.rcv.RcvNxt != e.snd.MaxSentAck {
e.snd.sendAck()
}
@@ -1147,18 +1172,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
return nil
}
-func (e *endpoint) probeSegment() {
- if e.probe != nil {
- e.probe(e.completeState())
+// Precondition: e.mu must be held.
+func (e *endpoint) probeSegmentLocked() {
+ if fn := e.probe; fn != nil {
+ fn(e.completeStateLocked())
}
}
// handleSegmentLocked handles a given segment and notifies the worker
// goroutine if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) {
// Invoke the tcp probe if installed. The tcp probe function will update
// the TCPEndpointState after the segment is processed.
- defer e.probeSegment()
+ defer e.probeSegmentLocked()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
@@ -1191,7 +1219,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
- s.window <<= e.snd.sndWndScale
+ s.window <<= e.snd.SndWndScale
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
@@ -1255,7 +1283,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
// seg.seq = snd.nxt-1.
e.keepalive.unacked++
e.keepalive.Unlock()
- e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1)
e.resetKeepaliveTimer(false)
return nil
}
@@ -1269,7 +1297,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt {
e.keepalive.timer.disable()
e.keepalive.Unlock()
return
@@ -1340,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Reaching this point means that we successfully completed the 3-way
- // handshake with our peer.
- //
+ // handshake with our peer. The current endpoint state could be any state
+ // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a
+ // RST from the peer via the dispatcher fast path, before the loop is
+ // started.
+ if s := e.EndpointState(); !s.connected() {
+ switch s {
+ case StateClose, StateError:
+		// If the endpoint is in the CLOSED/ERROR state, the sender state
+		// will already have been initialized if the endpoint was
+		// previously established.
+ if e.snd != nil {
+ break
+ }
+ fallthrough
+ default:
+ panic("endpoint was not established, current state " + s.String())
+ }
+ }
+
// Completing the 3-way handshake is an indication that the route is valid
// and the remote is reachable as the only way we can complete a handshake
// is if our SYN reached the remote and their ACK reached us.
@@ -1362,14 +1406,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
f func() tcpip.Error
}{
{
- w: &e.sndWaker,
+ w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
e.handleWrite()
return nil
},
},
{
- w: &e.sndCloseWaker,
+ w: &e.sndQueueInfo.sndCloseWaker,
f: func() tcpip.Error {
e.handleClose()
return nil
@@ -1403,7 +1447,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.newSegmentWaker,
f: func() tcpip.Error {
- return e.handleSegments(false /* fastPath */)
+ return e.handleSegmentsLocked(false /* fastPath */)
},
},
{
@@ -1419,11 +1463,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyMTUChanged != 0 {
- e.sndBufMu.Lock()
- count := e.packetTooBigCount
- e.packetTooBigCount = 0
- mtu := e.sndMTU
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ count := e.sndQueueInfo.PacketTooBigCount
+ e.sndQueueInfo.PacketTooBigCount = 0
+ mtu := e.sndQueueInfo.SndMTU
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.snd.updateMaxPayloadSize(mtu, count)
}
@@ -1453,7 +1497,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(false /* fastPath */); err != nil {
+ if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
return err
}
}
@@ -1504,11 +1548,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.newSegmentWaker.Assert()
}
- e.rcvListMu.Lock()
- if !e.rcvList.Empty() {
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if !e.rcvQueueInfo.rcvQueue.Empty() {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
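A recurring change in connect.go is that gso now travels by value, so "no GSO" is the zero value with Type == stack.GSONone rather than a nil pointer. A generic sketch of that zero-value-sentinel pattern, using local types rather than the stack package's:

package main

import "fmt"

type gsoType int

const (
	gsoNone gsoType = iota // zero value: offload disabled
	gsoSW
	gsoHW
)

// gso is passed by value; the zero value means "no offload", which
// replaces nil-pointer checks like `gso != nil && gso.NeedsCsum`.
type gso struct {
	Type      gsoType
	NeedsCsum bool
	MSS       uint16
}

func needsPartialChecksum(g gso) bool {
	return g.Type != gsoNone && g.NeedsCsum
}

func main() {
	fmt.Println(needsPartialChecksum(gso{}))                             // false
	fmt.Println(needsPartialChecksum(gso{Type: gsoHW, NeedsCsum: true})) // true
}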
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index 1975f1a44..962f1d687 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -17,6 +17,8 @@ package tcp
import (
"math"
"time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// cubicState stores the variables related to TCP CUBIC congestion
@@ -25,47 +27,12 @@ import (
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
- // wLastMax is the previous wMax value.
- wLastMax float64
-
- // wMax is the value of the congestion window at the
- // time of last congestion event.
- wMax float64
-
- // t denotes the time when the current congestion avoidance
- // was entered.
- t time.Time `state:".(unixTime)"`
+ stack.TCPCubicState
// numCongestionEvents tracks the number of congestion events since last
// RTO.
numCongestionEvents int
- // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
- // per RFC.
- c float64
-
- // k is the time period that the above function takes to increase the
- // current window size to W_max if there are no further congestion
- // events and is calculated using the following equation:
- //
- // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
- k float64
-
- // beta is the CUBIC multiplication decrease factor. that is, when a
- // congestion event is detected, CUBIC reduces its cwnd to
- // W_cubic(0)=W_max*beta_cubic.
- beta float64
-
- // wC is window computed by CUBIC at time t. It's calculated using the
- // formula:
- //
- // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
- wC float64
-
- // wEst is the window computed by CUBIC at time t+RTT i.e
- // W_cubic(t+RTT).
- wEst float64
-
s *sender
}
@@ -73,10 +40,12 @@ type cubicState struct {
// beta and c set and t set to current time.
func newCubicCC(s *sender) *cubicState {
return &cubicState{
- t: time.Now(),
- beta: 0.7,
- c: 0.4,
- s: s,
+ TCPCubicState: stack.TCPCubicState{
+ T: time.Now(),
+ Beta: 0.7,
+ C: 0.4,
+ },
+ s: s,
}
}
@@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() {
// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
// https://tools.ietf.org/html/rfc8312#section-4.8
if c.numCongestionEvents == 0 {
- c.k = 0
- c.t = time.Now()
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.K = 0
+ c.T = time.Now()
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
}
}
@@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() {
func (c *cubicState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
- newcwnd := c.s.sndCwnd + packetsAcked
+ newcwnd := c.s.SndCwnd + packetsAcked
enterCA := false
- if newcwnd >= c.s.sndSsthresh {
- newcwnd = c.s.sndSsthresh
- c.s.sndCAAckCount = 0
+ if newcwnd >= c.s.Ssthresh {
+ newcwnd = c.s.Ssthresh
+ c.s.SndCAAckCount = 0
enterCA = true
}
- packetsAcked -= newcwnd - c.s.sndCwnd
- c.s.sndCwnd = newcwnd
+ packetsAcked -= newcwnd - c.s.SndCwnd
+ c.s.SndCwnd = newcwnd
if enterCA {
c.enterCongestionAvoidance()
}
@@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int {
// ACK received.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) Update(packetsAcked int) {
- if c.s.sndCwnd < c.s.sndSsthresh {
+ if c.s.SndCwnd < c.s.Ssthresh {
packetsAcked = c.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
}
} else {
c.s.rtt.Lock()
- srtt := c.s.rtt.srtt
+ srtt := c.s.rtt.TCPRTTState.SRTT
c.s.rtt.Unlock()
- c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt)
}
}
// cubicCwnd computes the CUBIC congestion window after t seconds from last
// congestion event.
func (c *cubicState) cubicCwnd(t float64) float64 {
- return c.c*math.Pow(t, 3.0) + c.wMax
+ return c.C*math.Pow(t, 3.0) + c.WMax
}
// getCwnd returns the current congestion window as computed by CUBIC.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
- elapsed := time.Since(c.t).Seconds()
+ elapsed := time.Since(c.T).Seconds()
// Compute the window as per Cubic after 'elapsed' time
// since last congestion event.
- c.wC = c.cubicCwnd(elapsed - c.k)
+ c.WC = c.cubicCwnd(elapsed - c.K)
// Compute the TCP friendly estimate of the congestion window.
- c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+ c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds())
// Make sure in the TCP friendly region CUBIC performs at least
// as well as Reno.
- if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ if c.WC < c.WEst && float64(sndCwnd) < c.WEst {
// TCP Friendly region of cubic.
- return int(c.wEst)
+ return int(c.WEst)
}
// In Concave/Convex region of CUBIC, calculate what CUBIC window
// will be after 1 RTT and use that to grow congestion window
// for every ack.
- tEst := (time.Since(c.t) + srtt).Seconds()
- wtRtt := c.cubicCwnd(tEst - c.k)
+ tEst := (time.Since(c.T) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.K)
	// As per section 4.3, for each received ACK, cwnd must be incremented
	// by (W_cubic(t+RTT) - cwnd)/cwnd.
cwnd := float64(sndCwnd)
@@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
func (c *cubicState) HandleLossDetected() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
- c.t = time.Now()
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.T = time.Now()
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
c.fastConvergence()
c.reduceSlowStartThreshold()
@@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() {
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
- c.t = time.Now()
+ c.T = time.Now()
c.numCongestionEvents = 0
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
c.fastConvergence()
@@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() {
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
- c.s.sndCwnd = 1
+ c.s.SndCwnd = 1
}
// fastConvergence implements the logic for Fast Convergence algorithm as
// described in https://tools.ietf.org/html/rfc8312#section-4.6.
func (c *cubicState) fastConvergence() {
- if c.wMax < c.wLastMax {
- c.wLastMax = c.wMax
- c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ if c.WMax < c.WLastMax {
+ c.WLastMax = c.WMax
+ c.WMax = c.WMax * (1.0 + c.Beta) / 2.0
} else {
- c.wLastMax = c.wMax
+ c.WLastMax = c.WMax
}
// Recompute k as wMax may have changed.
- c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+ c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C)
}
// PostRecovery implements congestionControl.PostRecovery.
func (c *cubicState) PostRecovery() {
- c.t = time.Now()
+ c.T = time.Now()
}
// reduceSlowStartThreshold updates the slow start threshold as described in
// https://tools.ietf.org/html/rfc8312#section-4.7.
func (c *cubicState) reduceSlowStartThreshold() {
- c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+ c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0))
}
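With the fields hoisted into stack.TCPCubicState, the arithmetic is untouched: W_cubic(t) = C*(t-K)^3 + W_max (RFC 8312, Eq. 1) and K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2). Both formulas as a standalone program:

package main

import (
	"fmt"
	"math"
)

const (
	c    = 0.4 // cubic constant, fixed at 0.4 by RFC 8312
	beta = 0.7 // CUBIC multiplicative decrease factor
)

// cubicCwnd computes W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1).
func cubicCwnd(t, k, wMax float64) float64 {
	return c*math.Pow(t-k, 3) + wMax
}

// cubicK computes K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2):
// the time needed to climb back to wMax absent further losses.
func cubicK(wMax float64) float64 {
	return math.Cbrt(wMax * (1 - beta) / c)
}

func main() {
	wMax := 100.0
	k := cubicK(wMax)
	fmt.Printf("K = %.2fs\n", k)
	fmt.Printf("W_cubic(0) = %.1f\n", cubicCwnd(0, k, wMax)) // 70.0: wMax*beta right after the loss
	fmt.Printf("W_cubic(K) = %.1f\n", cubicCwnd(k, k, wMax)) // 100.0: back at wMax after K seconds
}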
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 21162f01a..512053a04 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
// If the endpoint is in a connected state then we do direct delivery
// to ensure low latency and avoid scheduler interactions.
- switch err := ep.handleSegments(true /* fastPath */); {
+ switch err := ep.handleSegmentsLocked(true /* fastPath */); {
case err != nil:
// Send any active resets if required.
ep.resetConnectionLocked(err)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index f6a16f96e..f148d505d 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
// Start connection attempt, it must fail.
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) {
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
_, _, err := c.EP.Accept(&addr)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -565,17 +566,15 @@ func TestV4AcceptOnV4(t *testing.T) {
}
func testV4ListenClose(t *testing.T, c *context.Context) {
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- const n = uint16(32)
+ const n = 32
// Start listening.
- if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ if err := c.EP.Listen(n); err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -591,9 +590,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
})
}
- // Each of these ACK's will cause a syn-cookie based connection to be
+ // Each of these ACKs will cause a syn-cookie based connection to be
// accepted and delivered to the listening endpoint.
- for i := uint16(0); i < n; i++ {
+ for i := 0; i < n; i++ {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
@@ -613,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.ReadableEvents)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
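The test changes replace type assertions with go-cmp, which follows pointers and compares the pointed-to error values structurally. The same idiom against an ordinary error type, for reference:

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

// ErrWouldBlock is a stand-in for the pointer-to-empty-struct errors
// used by the tcpip package.
type ErrWouldBlock struct{}

func (*ErrWouldBlock) Error() string { return "operation would block" }

func try() error { return &ErrWouldBlock{} }

func main() {
	err := try()
	// cmp.Equal dereferences both pointers, so two distinct
	// *ErrWouldBlock values still compare equal.
	fmt.Println(cmp.Equal(&ErrWouldBlock{}, err)) // true
	if d := cmp.Diff(&ErrWouldBlock{}, err); d != "" {
		fmt.Printf("mismatch (-want +got):\n%s", d)
	}
}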
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c5daba232..90edcfba6 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "container/list"
"encoding/binary"
"fmt"
"io"
@@ -190,42 +191,6 @@ type SACKInfo struct {
NumBlocks int
}
-// rcvBufAutoTuneParams are used to hold state variables to compute
-// the auto tuned recv buffer size.
-//
-// +stateify savable
-type rcvBufAutoTuneParams struct {
- // measureTime is the time at which the current measurement
- // was started.
- measureTime time.Time `state:".(unixTime)"`
-
- // copied is the number of bytes copied out of the receive
- // buffers since this measure began.
- copied int
-
- // prevCopied is the number of bytes copied out of the receive
- // buffers in the previous RTT period.
- prevCopied int
-
- // rtt is the non-smoothed minimum RTT as measured by observing the time
- // between when a byte is first acknowledged and the receipt of data
- // that is at least one window beyond the sequence number that was
- // acknowledged.
- rtt time.Duration
-
- // rttMeasureSeqNumber is the highest acceptable sequence number at the
- // time this RTT measurement period began.
- rttMeasureSeqNumber seqnum.Value
-
- // rttMeasureTime is the absolute time at which the current rtt
- // measurement period began.
- rttMeasureTime time.Time `state:".(unixTime)"`
-
- // disabled is true if an explicit receive buffer is set for the
- // endpoint.
- disabled bool
-}
-
// ReceiveErrors collect segment receive errors within transport layer.
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -246,7 +211,7 @@ type ReceiveErrors struct {
ListenOverflowAckDrop tcpip.StatCounter
// ZeroRcvWindowState is the number of times we advertised
- // a zero receive window when rcvList is full.
+ // a zero receive window when rcvQueue is full.
ZeroRcvWindowState tcpip.StatCounter
// WantZeroWindow is the number of times we wanted to advertise a
@@ -309,18 +274,45 @@ type Stats struct {
// marker interface.
func (*Stats) IsEndpointStats() {}
-// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools. This exists to allow tcp-only state to
-// be exposed.
+// sndQueueInfo implements a send queue.
//
// +stateify savable
-type EndpointInfo struct {
- stack.TransportEndpointInfo
+type sndQueueInfo struct {
+ sndQueueMu sync.Mutex `state:"nosave"`
+ stack.TCPSndBufState
+
+ // sndQueue holds segments that are ready to be sent.
+ sndQueue segmentList `state:"wait"`
+
+ // sndWaker is used to signal the protocol goroutine when segments are
+ // added to the sndQueue.
+ sndWaker sleep.Waker `state:"manual"`
+
+ // sndCloseWaker is used to notify the protocol goroutine when the send
+ // side is closed.
+ sndCloseWaker sleep.Waker `state:"manual"`
}
-// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
-// marker interface.
-func (*EndpointInfo) IsEndpointInfo() {}
+// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
+//
+// +stateify savable
+type rcvQueueInfo struct {
+ rcvQueueMu sync.Mutex `state:"nosave"`
+ stack.TCPRcvBufState
+
+ // rcvQueue is the queue for ready-for-delivery segments. This struct's
+ // mutex must be held in order to append segments to the list.
+ rcvQueue segmentList `state:"wait"`
+}
+
+// +stateify savable
+type accepted struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
+ cap int
+}
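With accepted backed by a container/list plus a capacity field instead of a buffered channel, the backlog check becomes an explicit length comparison. A standalone sketch of that shape, assuming nothing from the endpoint's API (boundedQueue and its methods are illustrative only):

package main

import (
	"container/list"
	"fmt"
)

// boundedQueue mirrors the accepted struct above: a list plus a soft
// capacity that callers must check before enqueueing.
type boundedQueue struct {
	items list.List
	cap   int
}

// tryPush enqueues v only if there is room, returning false otherwise.
func (q *boundedQueue) tryPush(v interface{}) bool {
	if q.items.Len() >= q.cap {
		return false
	}
	q.items.PushBack(v)
	return true
}

// pop removes and returns the oldest element, or nil if empty.
func (q *boundedQueue) pop() interface{} {
	front := q.items.Front()
	if front == nil {
		return nil
	}
	return q.items.Remove(front)
}

func main() {
	q := boundedQueue{cap: 1}
	fmt.Println(q.tryPush("a"), q.tryPush("b")) // true false
	fmt.Println(q.pop())                        // a
}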
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
@@ -337,9 +329,9 @@ func (*EndpointInfo) IsEndpointInfo() {}
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects acceptedChan.
-// e.rcvListMu -> Protects the rcvList and associated fields.
-// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.acceptMu -> protects accepted.
+// e.rcvQueueMu -> protects e.rcvQueue and associated fields.
+// e.sndQueueMu -> protects e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
//
// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
@@ -362,7 +354,8 @@ func (*EndpointInfo) IsEndpointInfo() {}
//
// +stateify savable
type endpoint struct {
- EndpointInfo
+ stack.TCPEndpointStateInner
+ stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
@@ -395,38 +388,23 @@ type endpoint struct {
// rcvReadMu synchronizes calls to Read.
//
- // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu
// must be held during each read to ensure atomicity, so that multiple reads
// do not interleave.
//
// rcvReadMu should be held before holding mu.
rcvReadMu sync.Mutex `state:"nosave"`
- // rcvListMu synchronizes access to rcvList.
- //
- // rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
-
- // rcvList is the queue for ready-for-delivery segments.
- //
- // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
- // and removing segments from list. A range of segment can be determined, then
- // temporarily release mu and rcvListMu while processing the segment range.
- // This allows new segments to be appended to the list while processing.
- //
- // rcvListMu must be held to append segments to list.
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- // rcvBufSize is the total size of the receive buffer.
- rcvBufSize int
- // rcvBufUsed is the actual number of payload bytes held in the receive buffer
- // not counting any overheads of the segments itself. NOTE: This will always
- // be strictly <= rcvMemUsed below.
- rcvBufUsed int
- rcvAutoParams rcvBufAutoTuneParams
+ // rcvQueueInfo holds the implementation of the endpoint's receive buffer.
+ // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu,
+ // and rcvQueueMu are held, in that stated order. To process a range of
+ // segments, first determine the range while holding the locks, then
+ // temporarily release mu and rcvQueueMu; this allows new segments to be
+ // appended to the queue while the range is processed.
+ rcvQueueInfo rcvQueueInfo
// rcvMemUsed tracks the total amount of memory in use by received segments
- // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to
// compute the window and the actual available buffer space. This is distinct
// from rcvBufUsed above which is the actual number of payload bytes held in
// the buffer not including any segment overheads.
@@ -488,33 +466,16 @@ type endpoint struct {
// also true, and they're both protected by the mutex.
workerCleanup bool
- // sendTSOk is used to indicate when the TS Option has been negotiated.
- // When sendTSOk is true every non-RST segment should carry a TS as per
- // RFC7323#section-1.1
- sendTSOk bool
-
- // recentTS is the timestamp that should be sent in the TSEcr field of
- // the timestamp for future segments sent by the endpoint. This field is
- // updated if required when a new segment is received by this endpoint.
- recentTS uint32
-
- // recentTSTime is the unix time when we updated recentTS last.
+ // recentTSTime is the unix time when we last updated
+ // TCPEndpointStateInner.RecentTS.
recentTSTime time.Time `state:".(unixTime)"`
- // tsOffset is a randomized offset added to the value of the
- // TSVal field in the timestamp option.
- tsOffset uint32
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
// tcpRecovery is the loss detection algorithm used by TCP.
tcpRecovery tcpip.TCPRecovery
- // sackPermitted is set to true if the peer sends the TCPSACKPermitted
- // option in the SYN/SYN-ACK.
- sackPermitted bool
-
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
@@ -550,32 +511,13 @@ type endpoint struct {
// this value.
windowClamp uint32
- // The following fields are used to manage the send buffer. When
- // segments are ready to be sent, they are added to sndQueue and the
- // protocol goroutine is signaled via sndWaker.
- //
- // When the send side is closed, the protocol goroutine is notified via
- // sndCloseWaker, and sndClosed is set to true.
- sndBufMu sync.Mutex `state:"nosave"`
- sndBufUsed int
- sndClosed bool
- sndBufInQueue seqnum.Size
- sndQueue segmentList `state:"wait"`
- sndWaker sleep.Waker `state:"manual"`
- sndCloseWaker sleep.Waker `state:"manual"`
+ // sndQueueInfo contains the implementation of the endpoint's send queue.
+ sndQueueInfo sndQueueInfo
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
cc tcpip.CongestionControlOption
- // The following are used when a "packet too big" control packet is
- // received. They are protected by sndBufMu. They are used to
- // communicate to the main protocol goroutine how many such control
- // messages have been received since the last notification was processed
- // and what was the smallest MTU seen.
- packetTooBigCount int
- sndMTU int
-
// newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it.
newSegmentWaker sleep.Waker `state:"manual"`
@@ -607,33 +549,26 @@ type endpoint struct {
// listener.
deferAccept time.Duration
- // pendingAccepted is a synchronization primitive used to track number
- // of connections that are queued up to be delivered to the accepted
- // channel. We use this to ensure that all goroutines blocked on writing
- // to the acceptedChan below terminate before we close acceptedChan.
+ // pendingAccepted tracks connections queued to be accepted. It is used to
+ // ensure such queued connections are terminated before the accepted queue is
+ // marked closed (by setting its capacity to zero).
pendingAccepted sync.WaitGroup `state:"nosave"`
- // acceptMu protects acceptedChan.
+ // acceptMu protects accepted.
acceptMu sync.Mutex `state:"nosave"`
// acceptCond is a condition variable that can be used to block on when
- // acceptedChan is full and an endpoint is ready to be delivered.
- //
- // This condition variable is required because just blocking on sending
- // to acceptedChan does not work in cases where endpoint.Listen is
- // called twice with different backlog values. In such cases the channel
- // is closed and a new one created. Any pending goroutines blocking on
- // the write to the channel will panic.
+ // accepted is full and an endpoint is ready to be delivered.
//
// We use this condition variable to block/unblock goroutines which
// tried to deliver an endpoint but couldn't because accept backlog was
// full (see endpoint.deliverAccepted).
acceptCond *sync.Cond `state:"nosave"`
- // acceptedChan is used by a listening endpoint protocol goroutine to
+ // accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- acceptedChan chan *endpoint `state:".([]*endpoint)"`
+ accepted accepted
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -664,7 +599,7 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- gso *stack.GSO
+ gso stack.GSO
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats Stats `state:"nosave"`
@@ -779,7 +714,7 @@ func (e *endpoint) UnlockUser() {
switch e.EndpointState() {
case StateEstablished:
- if err := e.handleSegments(true /* fastPath */); err != nil {
+ if err := e.handleSegmentsLocked(true /* fastPath */); err != nil {
e.notifyProtocolGoroutine(notifyTickleWorker)
}
default:
@@ -839,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState {
// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- e.recentTS = recentTS
+ e.RecentTS = recentTS
e.recentTSTime = time.Now()
}
// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return e.recentTS
+ return e.RecentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -865,16 +800,17 @@ type keepalive struct {
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
stack: s,
- EndpointInfo: EndpointInfo{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ sndQueueInfo: sndQueueInfo{
+ TCPSndBufState: stack.TCPSndBufState{
+ SndMTU: int(math.MaxInt32),
},
},
waiterQueue: waiterQueue,
state: StateInitial,
- rcvBufSize: DefaultReceiveBufferSize,
- sndMTU: int(math.MaxInt32),
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -886,10 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetQuickAck(true)
e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */)
+ e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -898,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- e.rcvBufSize = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
var cs tcpip.CongestionControlOption
@@ -908,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var mrb tcpip.TCPModerateReceiveBufferOption
if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
- e.rcvAutoParams.disabled = !bool(mrb)
+ e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb)
}
var de tcpip.TCPDelayEnabled
@@ -933,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.tsOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(&e.keepalive.waker)
@@ -959,10 +896,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result = mask
case StateListen:
- // Check if there's anything in the accepted channel.
+ // Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if len(e.acceptedChan) > 0 {
+ if e.accepted.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -971,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.WritableEvents) != 0 {
- e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Lock()
sndBufSize := e.getSendBufferSize()
- if e.sndClosed || e.sndBufUsed < sndBufSize {
+ if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize {
result |= waiter.WritableEvents
}
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
}
// Determine if the endpoint is readable if requested.
if (mask & waiter.ReadableEvents) != 0 {
- e.rcvListMu.Lock()
- if e.rcvBufUsed > 0 || e.rcvClosed {
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed {
result |= waiter.ReadableEvents
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
}
}
@@ -1093,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
@@ -1145,22 +1082,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
+ acceptedCopy := e.accepted
+ e.accepted = accepted{}
+ e.acceptMu.Unlock()
+
+ if acceptedCopy == (accepted{}) {
return
}
- close(e.acceptedChan)
- ch := e.acceptedChan
- e.acceptedChan = nil
+
e.acceptCond.Broadcast()
- e.acceptMu.Unlock()
// Reset all connections that are waiting to be accepted.
- for n := range ch {
- n.notifyProtocolGoroutine(notifyReset)
+ for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
// Wait for reset of all endpoints that are still waiting to be delivered to
- // the now closed acceptedChan.
+ // the now closed accepted.
e.pendingAccepted.Wait()
}
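The rewrite above snapshots the queue under acceptMu, resets it, and only then notifies the queued endpoints with the lock released. A hedged sketch of that snapshot-then-notify idiom, with plain strings standing in for endpoints:

package main

import (
	"container/list"
	"fmt"
	"sync"
)

type notifier struct {
	mu      sync.Mutex
	pending list.List
}

// drainAndNotify empties the queue under the lock, then acts on each
// element with the lock released, matching the pattern above.
func (n *notifier) drainAndNotify() {
	n.mu.Lock()
	var snapshot []string
	for e := n.pending.Front(); e != nil; e = e.Next() {
		snapshot = append(snapshot, e.Value.(string))
	}
	n.pending.Init() // reset to empty, like e.accepted = accepted{}
	n.mu.Unlock()

	for _, v := range snapshot {
		fmt.Println("notify:", v) // stands in for notifyProtocolGoroutine
	}
}

func main() {
	var n notifier
	n.pending.PushBack("ep1")
	n.pending.PushBack("ep2")
	n.drainAndNotify()
}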
@@ -1176,7 +1113,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1184,8 +1121,8 @@ func (e *endpoint) cleanupLocked() {
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
@@ -1247,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- if e.rcvAutoParams.disabled {
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if e.rcvQueueInfo.RcvAutoParams.Disabled {
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
now := time.Now()
- if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
- e.rcvAutoParams.copied += copied
- e.rcvListMu.Unlock()
+ if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt {
+ e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- prevRTTCopied := e.rcvAutoParams.copied + copied
- prevCopied := e.rcvAutoParams.prevCopied
+ prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied
+ prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes
rcvWnd := 0
if prevRTTCopied > prevCopied {
// The minimal receive window based on what was copied by the app
@@ -1291,24 +1228,25 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// We do not adjust downwards as that can cause the receiver to
// reject valid data that might already be in flight as the
// acceptable window will shrink.
- if rcvWnd > e.rcvBufSize {
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvBufSize = rcvWnd
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ rcvBufSize := int(e.ops.GetReceiveBufferSize())
+ if rcvWnd > rcvBufSize {
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd))
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
- // We only update prevCopied when we grow the buffer because in cases
- // where prevCopied > prevRTTCopied the existing buffer is already big
+ // We only update PrevCopiedBytes when we grow the buffer because in cases
+ // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big
// enough to handle the current rate and we don't need to do any
// adjustments.
- e.rcvAutoParams.prevCopied = prevRTTCopied
+ e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied
}
- e.rcvAutoParams.measureTime = now
- e.rcvAutoParams.copied = 0
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.RcvAutoParams.MeasureTime = now
+ e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
}
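ModerateRecvBuf grows, and never shrinks, the receive buffer when the application's copy rate over the last RTT exceeds the previous period's. A loose standalone sketch of that growth rule; the real computation above also folds in segment overheads and the ACK-threshold notification:

package main

import "fmt"

// nextRcvBufSize sketches the moderation rule: if the application copied
// more this RTT than in the previous one, grow the buffer toward twice
// the observed rate, capped at max; otherwise leave it alone.
func nextRcvBufSize(cur, copiedThisRTT, copiedPrevRTT, max int) int {
	if copiedThisRTT <= copiedPrevRTT {
		return cur // the existing buffer already covers the rate
	}
	want := 2 * copiedThisRTT
	if want > max {
		want = max
	}
	if want > cur {
		return want
	}
	return cur
}

func main() {
	fmt.Println(nextRcvBufSize(64<<10, 48<<10, 32<<10, 4<<20)) // 98304 (96 KiB)
}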
// SetOwner implements tcpip.Endpoint.SetOwner.
@@ -1342,6 +1280,12 @@ func (e *endpoint) LastError() tcpip.Error {
return e.lastErrorLocked()
}
+// LastErrorLocked reads and clears lastError with e.mu held.
+// Only to be used in tests.
+func (e *endpoint) LastErrorLocked() tcpip.Error {
+ return e.lastErrorLocked()
+}
+
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.LockUser()
@@ -1357,7 +1301,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
defer e.rcvReadMu.Unlock()
// N.B. Here we get a range of segments to be processed. It is safe to not
- // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we
// can remove segments from the list through commitRead().
first, last, serr := e.startRead()
if serr != nil {
@@ -1429,10 +1373,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
- bufUsed := e.rcvBufUsed
+ bufUsed := e.rcvQueueInfo.RcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
@@ -1444,14 +1388,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
return nil, nil, &tcpip.ErrNotConnected{}
}
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
+ if e.rcvQueueInfo.RcvBufUsed == 0 {
+ if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() {
return nil, nil, &tcpip.ErrClosedForReceive{}
}
return nil, nil, &tcpip.ErrWouldBlock{}
}
- return e.rcvList.Front(), e.rcvList.Back(), nil
+ return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil
}
// commitRead commits a read of done bytes and returns the next non-empty
@@ -1467,39 +1411,39 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
func (e *endpoint) commitRead(done int) *segment {
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
memDelta := 0
- s := e.rcvList.Front()
+ s := e.rcvQueueInfo.rcvQueue.Front()
for s != nil && s.data.Size() == 0 {
- e.rcvList.Remove(s)
+ e.rcvQueueInfo.rcvQueue.Remove(s)
// Memory is only considered released when the whole segment has been
// read.
memDelta += s.segMemSize()
s.decRef()
- s = e.rcvList.Front()
+ s = e.rcvQueueInfo.rcvQueue.Front()
}
- e.rcvBufUsed -= done
+ e.rcvQueueInfo.RcvBufUsed -= done
if memDelta > 0 {
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
- return e.rcvList.Front()
+ return e.rcvQueueInfo.rcvQueue.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
// and also returns the number of bytes that can be written at this
// moment. If the endpoint is not writable then it returns an error
// indicating the reason why it's not writable.
-// Caller must hold e.mu and e.sndBufMu
+// Caller must hold e.mu and e.sndQueueMu
func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
@@ -1519,12 +1463,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
}
// Check if the connection has already been closed for sends.
- if e.sndClosed {
+ if e.sndQueueInfo.SndClosed {
return 0, &tcpip.ErrClosedForSend{}
}
sndBufSize := e.getSendBufferSize()
- avail := sndBufSize - e.sndBufUsed
+ avail := sndBufSize - e.sndQueueInfo.SndBufUsed
if avail <= 0 {
return 0, &tcpip.ErrWouldBlock{}
}
@@ -1541,8 +1485,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
defer e.UnlockUser()
nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndBufMu.Lock()
- defer e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
@@ -1557,8 +1501,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// available buffer space to be consumed by some other caller while we
// are copying data in.
if !opts.Atomic {
- e.sndBufMu.Unlock()
- defer e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
e.UnlockUser()
defer e.LockUser()
@@ -1600,10 +1544,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
// Add data to the send queue.
- s := newOutgoingSegment(e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
+ e.sndQueueInfo.sndQueue.PushBack(s)
return e.drainSendQueueLocked(), len(v), nil
}()
@@ -1618,11 +1562,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
-// Precondition: e.mu and e.rcvListMu must be held.
-func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
- wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
- maxWindow := wndFromSpace(e.rcvBufSize)
- wndFromUsedBytes := maxWindow - e.rcvBufUsed
+// Precondition: e.mu and e.rcvQueueMu must be held.
+func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ maxWindow := wndFromSpace(rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed
// We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
// cases where we receive a lot of small segments the segment overhead is a
@@ -1640,11 +1584,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
return seqnum.Size(newWnd)
}
-// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu.
func (e *endpoint) selectWindow() (wnd seqnum.Size) {
- e.rcvListMu.Lock()
- wnd = e.selectWindowLocked()
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize()))
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return wnd
}
@@ -1662,9 +1606,9 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) {
// above will be true if the new window is >= ACK threshold and false
// otherwise.
//
-// Precondition: e.mu and e.rcvListMu must be held.
-func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := int(e.selectWindowLocked())
+// Precondition: e.mu and e.rcvQueueMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) {
+ newAvail := int(e.selectWindowLocked(rcvBufSize))
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -1673,7 +1617,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
// rcvBufFraction is the inverse of the fraction of receive buffer size that
// is used to decide if the available buffer space is now above it.
const rcvBufFraction = 2
- if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold {
threshold = wndThreshold
}
switch {
@@ -1700,7 +1644,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
}
// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
-func (e *endpoint) OnKeepAliveSet(v bool) {
+func (e *endpoint) OnKeepAliveSet(bool) {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
}
@@ -1708,7 +1652,7 @@ func (e *endpoint) OnKeepAliveSet(v bool) {
func (e *endpoint) OnDelayOptionSet(v bool) {
if !v {
// Handle delayed data.
- e.sndWaker.Assert()
+ e.sndQueueInfo.sndWaker.Assert()
}
}
@@ -1716,7 +1660,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) {
func (e *endpoint) OnCorkOptionSet(v bool) {
if !v {
// Handle the corked data.
- e.sndWaker.Assert()
+ e.sndQueueInfo.sndWaker.Assert()
}
}
@@ -1724,6 +1668,37 @@ func (e *endpoint) getSendBufferSize() int {
return int(e.ops.GetSendBufferSize())
}
+// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize.
+func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
+ e.LockUser()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.RcvWndScale
+ }
+ if rcvBufSz>>scale == 0 {
+ rcvBufSz = 1 << scale
+ }
+
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz)))
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz)))
+ e.rcvQueueInfo.RcvAutoParams.Disabled = true
+
+ // Immediately send an ACK to uncork the sender's silly window
+ // syndrome prevention, when our available space grows above aMSS
+ // or half the receive buffer, whichever is smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+ e.UnlockUser()
+ return rcvBufSz
+}
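The scale guard above rounds the requested size up so that rcvBufSz>>scale is never zero; otherwise the endpoint would advertise a zero receive window. A small self-contained check of the same arithmetic, with hypothetical values:

package main

import "fmt"

// clampToScale mirrors the guard in OnSetReceiveBufferSize: a buffer
// smaller than 1<<scale would advertise a zero window, so round it up.
func clampToScale(rcvBufSz int64, scale uint8) int64 {
	if rcvBufSz>>scale == 0 {
		rcvBufSz = 1 << scale
	}
	return rcvBufSz
}

func main() {
	// With a window scale of 7, anything under 128 bytes rounds up to 128.
	fmt.Println(clampToScale(100, 7))  // 128
	fmt.Println(clampToScale(4096, 7)) // 4096
}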
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -1767,56 +1742,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs tcpip.TCPReceiveBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
- panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
- }
-
- if v > rs.Max {
- v = rs.Max
- }
-
- if v < math.MaxInt32/SegOverheadFactor {
- v *= SegOverheadFactor
- if v < rs.Min {
- v = rs.Min
- }
- } else {
- v = math.MaxInt32
- }
-
- e.LockUser()
- e.rcvListMu.Lock()
-
- // Make sure the receive buffer size allows us to send a
- // non-zero window size.
- scale := uint8(0)
- if e.rcv != nil {
- scale = e.rcv.rcvWndScale
- }
- if v>>scale == 0 {
- v = 1 << scale
- }
-
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvBufSize = v
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
-
- e.rcvAutoParams.disabled = true
-
- // Immediately send an ACK to uncork the sender silly window
- // syndrome prevetion, when our available space grows above aMSS
- // or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
- }
-
- e.rcvListMu.Unlock()
- e.UnlockUser()
-
case tcpip.TTLOption:
e.LockUser()
e.ttl = uint8(v)
@@ -1959,10 +1884,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) {
return 0, &tcpip.ErrInvalidEndpointState{}
}
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
- return e.rcvBufUsed, nil
+ return e.rcvQueueInfo.RcvBufUsed, nil
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -2002,12 +1927,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
- case tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- v := e.rcvBufSize
- e.rcvListMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.LockUser()
v := int(e.ttl)
@@ -2043,15 +1962,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
// the connection did not send and receive data, then RTT will
// be zero.
snd.rtt.Lock()
- info.RTT = snd.rtt.srtt
- info.RTTVar = snd.rtt.rttvar
+ info.RTT = snd.rtt.TCPRTTState.SRTT
+ info.RTTVar = snd.rtt.TCPRTTState.RTTVar
snd.rtt.Unlock()
- info.RTO = snd.rto
+ info.RTO = snd.RTO
info.CcState = snd.state
- info.SndSsthresh = uint32(snd.sndSsthresh)
- info.SndCwnd = uint32(snd.sndCwnd)
- info.ReorderSeen = snd.rc.reorderSeen
+ info.SndSsthresh = uint32(snd.Ssthresh)
+ info.SndCwnd = uint32(snd.SndCwnd)
+ info.ReorderSeen = snd.rc.Reord
}
e.UnlockUser()
return info
@@ -2096,7 +2015,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
e.UnlockUser()
if err != nil {
return err
@@ -2204,20 +2123,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- e.ID.LocalAddress = r.LocalAddress()
- e.ID.RemoteAddress = r.RemoteAddress()
- e.ID.RemotePort = addr.Port
+ e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress()
+ e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress()
+ e.TransportEndpointInfo.ID.RemotePort = addr.Port
- if e.ID.LocalPort != 0 {
+ if e.TransportEndpointInfo.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2226,19 +2145,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// one. Make sure that it isn't one that will result in the same
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
- sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+ sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress
// Calculate a port offset based on the destination IP/port and
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.Seed())
- h.Write([]byte(e.ID.LocalAddress))
- h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h.Write(portBuf)
- portOffset := uint16(h.Sum32())
+
+ h := jenkins.Sum32(e.stack.Seed())
+ for _, s := range [][]byte{
+ []byte(e.ID.LocalAddress),
+ []byte(e.ID.RemoteAddress),
+ portBuf,
+ } {
+ // Per io.Writer.Write:
+ //
+ // Write must return a non-nil error if it returns n < len(p).
+ if _, err := h.Write(s); err != nil {
+ panic(err)
+ }
+ }
+ portOffset := h.Sum32()
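The loop feeds each tuple component through the hash and panics on a short write, which io.Writer permits only alongside a non-nil error. A runnable sketch of deriving a stable ephemeral-port offset the same way; hash/fnv stands in here for gvisor's internal jenkins package:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// portOffset hashes (localAddr, remoteAddr, remotePort) so the same
// tuple always starts its ephemeral port search at the same offset.
func portOffset(localAddr, remoteAddr string, remotePort uint16) uint32 {
	portBuf := make([]byte, 2)
	binary.LittleEndian.PutUint16(portBuf, remotePort)

	h := fnv.New32a()
	for _, s := range [][]byte{[]byte(localAddr), []byte(remoteAddr), portBuf} {
		// Per io.Writer: Write must return an error if it writes n < len(p).
		if _, err := h.Write(s); err != nil {
			panic(err)
		}
	}
	return h.Sum32()
}

func main() {
	fmt.Println(portOffset("192.0.2.1", "198.51.100.2", 443))
}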
var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@@ -2249,21 +2178,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
switch netProto {
case header.IPv4ProtocolNumber:
- reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress)
case header.IPv6ProtocolNumber:
- reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback
}
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) {
- if sameAddr && p == e.ID.RemotePort {
+ if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort {
return false, nil
}
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2273,7 +2202,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
- transEPID := e.ID
+ transEPID := e.TransportEndpointInfo.ID
transEPID.LocalPort = p
// Check if an endpoint is registered with demuxer in TIME-WAIT and if
// we can reuse it. If we can't find a transport endpoint then we just
@@ -2310,7 +2239,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2321,13 +2250,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
}
}
- id := e.ID
+ id := e.TransportEndpointInfo.ID
id.LocalPort = p
if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2342,13 +2271,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// Port picking successful. Save the details of
// the selected port.
- e.ID = id
+ e.TransportEndpointInfo.ID = id
e.isPortReserved = true
e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
}); err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}
}
@@ -2367,10 +2297,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
- s.id = e.ID
- e.sndWaker.Assert()
+ s.id = e.TransportEndpointInfo.ID
+ e.sndQueueInfo.sndWaker.Assert()
}
}
e.segmentQueue.mu.Unlock()
@@ -2412,10 +2342,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Close for read.
if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
- e.rcvListMu.Lock()
- e.rcvClosed = true
- rcvBufUsed := e.rcvBufUsed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = true
+ rcvBufUsed := e.rcvQueueInfo.RcvBufUsed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
@@ -2429,10 +2359,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Close for write.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- e.sndBufMu.Lock()
- if e.sndClosed {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ if e.sndQueueInfo.SndClosed {
// Already closed.
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
if e.EndpointState() == StateTimeWait {
return &tcpip.ErrNotConnected{}
}
@@ -2440,12 +2370,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Queue fin segment.
- s := newOutgoingSegment(e.ID, nil)
- e.sndQueue.PushBack(s)
- e.sndBufInQueue++
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil)
+ e.sndQueueInfo.sndQueue.PushBack(s)
+ e.sndQueueInfo.SndBufInQueue++
// Mark endpoint as closed.
- e.sndClosed = true
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.SndClosed = true
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.handleClose()
}
@@ -2458,9 +2388,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
//
// By not removing this endpoint from the demuxer mapping, we
// ensure that any other bind to the same port fails, as on Linux.
- e.rcvListMu.Lock()
- e.rcvClosed = true
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = true
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.closePendingAcceptableConnectionsLocked()
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
@@ -2491,28 +2421,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.acceptedChan == nil {
+ if e.accepted == (accepted{}) {
// listen is called after shutdown.
- e.acceptedChan = make(chan *endpoint, backlog)
+ e.accepted.cap = backlog
e.shutdownFlags = 0
- e.rcvListMu.Lock()
- e.rcvClosed = false
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
} else {
- // Adjust the size of the channel iff we can fix
+ // Adjust the size of the backlog iff we can fit
// existing pending connections into the new one.
- if len(e.acceptedChan) > backlog {
+ if e.accepted.endpoints.Len() > backlog {
return &tcpip.ErrInvalidEndpointState{}
}
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
- }
+ e.accepted.cap = backlog
}
// Notify any blocked goroutines that they can attempt to
@@ -2538,19 +2460,19 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
e.isRegistered = true
e.setEndpointState(StateListen)
- // The channel may be non-nil when we're restoring the endpoint, and it
+ // The queue may be non-zero when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptedChan = make(chan *endpoint, backlog)
+ if e.accepted == (accepted{}) {
+ e.accepted.cap = backlog
}
e.acceptMu.Unlock()
@@ -2578,24 +2500,25 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- rcvClosed := e.rcvClosed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := e.rcvQueueInfo.RcvClosed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, &tcpip.ErrInvalidEndpointState{}
}
// Get the new accepted endpoint.
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
var n *endpoint
- select {
- case n = <-e.acceptedChan:
- e.acceptCond.Signal()
- default:
+ e.acceptMu.Lock()
+ if element := e.accepted.endpoints.Front(); element != nil {
+ n = e.accepted.endpoints.Remove(element).(*endpoint)
+ }
+ e.acceptMu.Unlock()
+ if n == nil {
return nil, nil, &tcpip.ErrWouldBlock{}
}
+ e.acceptCond.Signal()
if peerAddr != nil {
*peerAddr = n.getRemoteAddress()
}
@@ -2645,7 +2568,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
if nic == 0 {
return &tcpip.ErrBadLocalAddress{}
}
- e.ID.LocalAddress = addr.Addr
+ e.TransportEndpointInfo.ID.LocalAddress = addr.Addr
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
@@ -2659,7 +2582,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
Dest: tcpip.FullAddress{},
}
port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
- id := e.ID
+ id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
// listening endpoint bound with the same id and portFlags and bindToDevice
@@ -2675,6 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
return true, nil
})
if err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}
@@ -2684,7 +2608,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
e.boundNICID = nic
e.isPortReserved = true
e.effectiveNetProtos = netProtos
- e.ID.LocalPort = port
+ e.TransportEndpointInfo.ID.LocalPort = port
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2698,8 +2622,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
defer e.UnlockUser()
return tcpip.FullAddress{
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
NIC: e.boundNICID,
}, nil
}
@@ -2718,8 +2642,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
NIC: e.boundNICID,
}
}
@@ -2758,13 +2682,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
Payload: pkt.Data().AsRange().ToOwnedView(),
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -2777,12 +2701,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
// HandleError implements stack.TransportEndpoint.
func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
handlePacketTooBig := func(mtu uint32) {
- e.sndBufMu.Lock()
- e.packetTooBigCount++
- if v := int(mtu); v < e.sndMTU {
- e.sndMTU = v
+ e.sndQueueInfo.sndQueueMu.Lock()
+ e.sndQueueInfo.PacketTooBigCount++
+ if v := int(mtu); v < e.sndQueueInfo.SndMTU {
+ e.sndQueueInfo.SndMTU = v
}
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
}
@@ -2801,14 +2725,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// in the send buffer. The number of newly available bytes is v.
func (e *endpoint) updateSndBufferUsage(v int) {
sendBufferSize := e.getSendBufferSize()
- e.sndBufMu.Lock()
- notify := e.sndBufUsed >= sendBufferSize>>1
- e.sndBufUsed -= v
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
+ e.sndQueueInfo.SndBufUsed -= v
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndBufUsed < int(sendBufferSize)>>1
- e.sndBufMu.Unlock()
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
e.waiterQueue.Notify(waiter.WritableEvents)
@@ -2819,58 +2743,50 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
- e.rcvListMu.Lock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
if s != nil {
- e.rcvBufUsed += s.payloadSize()
+ e.rcvQueueInfo.RcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvList.PushBack(s)
+ e.rcvQueueInfo.rcvQueue.PushBack(s)
} else {
- e.rcvClosed = true
+ e.rcvQueueInfo.RcvClosed = true
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.waiterQueue.Notify(waiter.ReadableEvents)
}
// receiveBufferAvailableLocked calculates how many bytes are still available
// in the receive buffer.
-// rcvListMu must be held when this function is called.
-func (e *endpoint) receiveBufferAvailableLocked() int {
+// rcvQueueMu must be held when this function is called.
+func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
memUsed := e.receiveMemUsed()
- if memUsed >= e.rcvBufSize {
+ if memUsed >= rcvBufSize {
return 0
}
- return e.rcvBufSize - memUsed
+ return rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
// receive buffer based on the actual memory used by all segments held in
// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
- e.rcvListMu.Lock()
- available := e.receiveBufferAvailableLocked()
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize()))
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return available
}
// receiveBufferUsed returns the amount of in-use receive buffer.
func (e *endpoint) receiveBufferUsed() int {
- e.rcvListMu.Lock()
- used := e.rcvBufUsed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ used := e.rcvQueueInfo.RcvBufUsed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return used
}
-// receiveBufferSize returns the current size of the receive buffer.
-func (e *endpoint) receiveBufferSize() int {
- e.rcvListMu.Lock()
- size := e.rcvBufSize
- e.rcvListMu.Unlock()
- return size
-}
-
// receiveMemUsed returns the total memory in use by segments held by this
// endpoint.
func (e *endpoint) receiveMemUsed() int {
@@ -2899,11 +2815,11 @@ func (e *endpoint) maxReceiveBufferSize() int {
// receiveBuffer otherwise we use the max permissible receive buffer size to
// compute the scale.
func (e *endpoint) rcvWndScaleForHandshake() int {
- bufSizeForScale := e.receiveBufferSize()
+ bufSizeForScale := e.ops.GetReceiveBufferSize()
- e.rcvListMu.Lock()
- autoTuningDisabled := e.rcvAutoParams.disabled
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if autoTuningDisabled {
return FindWndScale(seqnum.Size(bufSizeForScale))
}
@@ -2914,7 +2830,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
e.setRecentTimestamp(tsVal)
}
}
@@ -2924,7 +2840,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// initializes the recentTS with the value provided in synOpts.TSval.
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
- e.sendTSOk = true
+ e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2932,7 +2848,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(time.Now(), e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.TSOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
@@ -2971,7 +2887,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
return
}
if bool(v) && synOpts.SACKPermitted {
- e.sackPermitted = true
+ e.SACKPermitted = true
}
}
@@ -2985,144 +2901,70 @@ func (e *endpoint) maxOptionSize() (size int) {
return size
}
-// completeState makes a full copy of the endpoint and returns it. This is used
-// before invoking the probe. The state returned may not be fully consistent if
-// there are intervening syscalls when the state is being copied.
-func (e *endpoint) completeState() stack.TCPEndpointState {
- var s stack.TCPEndpointState
- s.SegTime = time.Now()
-
- // Copy EndpointID.
- s.ID = stack.TCPEndpointID(e.ID)
-
- // Copy endpoint rcv state.
- e.rcvListMu.Lock()
- s.RcvBufSize = e.rcvBufSize
- s.RcvBufUsed = e.rcvBufUsed
- s.RcvClosed = e.rcvClosed
- s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
- s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
- s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
- s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
- s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
- s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
- s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
- e.rcvListMu.Unlock()
-
- // Endpoint TCP Option state.
- s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTimestamp()
- s.TSOffset = e.tsOffset
- s.SACKPermitted = e.sackPermitted
+// completeStateLocked makes a full copy of the endpoint and returns it. This is
+// used before invoking the probe.
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
+ s := stack.TCPEndpointState{
+ TCPEndpointStateInner: e.TCPEndpointStateInner,
+ ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID),
+ SegTime: time.Now(),
+ Receiver: e.rcv.TCPReceiverState,
+ Sender: e.snd.TCPSenderState,
+ }
+
+ sndBufSize := e.getSendBufferSize()
+ // Copy the send buffer atomically.
+ e.sndQueueInfo.sndQueueMu.Lock()
+ s.SndBufState = e.sndQueueInfo.TCPSndBufState
+ s.SndBufState.SndBufSize = sndBufSize
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ // Copy the receive buffer atomically.
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ // Copy the endpoint TCP Option state.
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
- // Copy endpoint send state.
- sndBufSize := e.getSendBufferSize()
- e.sndBufMu.Lock()
- s.SndBufSize = sndBufSize
- s.SndBufUsed = e.sndBufUsed
- s.SndClosed = e.sndClosed
- s.SndBufInQueue = e.sndBufInQueue
- s.PacketTooBigCount = e.packetTooBigCount
- s.SndMTU = e.sndMTU
- e.sndBufMu.Unlock()
-
- // Copy receiver state.
- s.Receiver = stack.TCPReceiverState{
- RcvNxt: e.rcv.rcvNxt,
- RcvAcc: e.rcv.rcvAcc,
- RcvWndScale: e.rcv.rcvWndScale,
- PendingBufUsed: e.rcv.pendingBufUsed,
- }
-
- // Copy sender state.
- s.Sender = stack.TCPSenderState{
- LastSendTime: e.snd.lastSendTime,
- DupAckCount: e.snd.dupAckCount,
- FastRecovery: stack.TCPFastRecoveryState{
- Active: e.snd.fr.active,
- First: e.snd.fr.first,
- Last: e.snd.fr.last,
- MaxCwnd: e.snd.fr.maxCwnd,
- HighRxt: e.snd.fr.highRxt,
- RescueRxt: e.snd.fr.rescueRxt,
- },
- SndCwnd: e.snd.sndCwnd,
- Ssthresh: e.snd.sndSsthresh,
- SndCAAckCount: e.snd.sndCAAckCount,
- Outstanding: e.snd.outstanding,
- SackedOut: e.snd.sackedOut,
- SndWnd: e.snd.sndWnd,
- SndUna: e.snd.sndUna,
- SndNxt: e.snd.sndNxt,
- RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
- RTTMeasureTime: e.snd.rttMeasureTime,
- Closed: e.snd.closed,
- RTO: e.snd.rto,
- MaxPayloadSize: e.snd.maxPayloadSize,
- SndWndScale: e.snd.sndWndScale,
- MaxSentAck: e.snd.maxSentAck,
- }
e.snd.rtt.Lock()
- s.Sender.SRTT = e.snd.rtt.srtt
- s.Sender.SRTTInited = e.snd.rtt.srttInited
+ s.Sender.RTTState = e.snd.rtt.TCPRTTState
e.snd.rtt.Unlock()
if cubic, ok := e.snd.cc.(*cubicState); ok {
- s.Sender.Cubic = stack.TCPCubicState{
- WMax: cubic.wMax,
- WLastMax: cubic.wLastMax,
- T: cubic.t,
- TimeSinceLastCongestion: time.Since(cubic.t),
- C: cubic.c,
- K: cubic.k,
- Beta: cubic.beta,
- WC: cubic.wC,
- WEst: cubic.wEst,
- }
+ s.Sender.Cubic = cubic.TCPCubicState
+ s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T)
}
- rc := &e.snd.rc
- s.Sender.RACKState = stack.TCPRACKState{
- XmitTime: rc.xmitTime,
- EndSequence: rc.endSequence,
- FACK: rc.fack,
- RTT: rc.rtt,
- Reord: rc.reorderSeen,
- DSACKSeen: rc.dsackSeen,
- ReoWnd: rc.reoWnd,
- ReoWndIncr: rc.reoWndIncr,
- ReoWndPersist: rc.reoWndPersist,
- RTTSeq: rc.rttSeq,
- }
+ s.Sender.RACKState = e.snd.rc.TCPRACKState
return s
}
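
The rewritten completeStateLocked works because the exported state now lives in plain value structs (TCPEndpointStateInner, TCPSenderState, TCPReceiverState), so a consistent copy is a struct assignment taken under the appropriate mutex. A minimal, self-contained sketch of that snapshot-by-value pattern, with hypothetical names:

    package main

    import (
        "fmt"
        "sync"
    )

    type stats struct {
        Sent, Received int
    }

    type conn struct {
        mu sync.Mutex
        stats
    }

    // snapshot copies the embedded stats by value while holding the
    // lock; the caller gets a consistent copy it can read lock-free.
    func (c *conn) snapshot() stats {
        c.mu.Lock()
        defer c.mu.Unlock()
        return c.stats
    }

    func main() {
        c := &conn{stats: stats{Sent: 3, Received: 5}}
        fmt.Println(c.snapshot()) // {3 5}
    }
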
func (e *endpoint) initHardwareGSO() {
- gso := &stack.GSO{}
switch e.route.NetProto() {
case header.IPv4ProtocolNumber:
- gso.Type = stack.GSOTCPv4
- gso.L3HdrLen = header.IPv4MinimumSize
+ e.gso.Type = stack.GSOTCPv4
+ e.gso.L3HdrLen = header.IPv4MinimumSize
case header.IPv6ProtocolNumber:
- gso.Type = stack.GSOTCPv6
- gso.L3HdrLen = header.IPv6MinimumSize
+ e.gso.Type = stack.GSOTCPv6
+ e.gso.L3HdrLen = header.IPv6MinimumSize
default:
panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto))
}
- gso.NeedsCsum = true
- gso.CsumOffset = header.TCPChecksumOffset
- gso.MaxSize = e.route.GSOMaxSize()
- e.gso = gso
+ e.gso.NeedsCsum = true
+ e.gso.CsumOffset = header.TCPChecksumOffset
+ e.gso.MaxSize = e.route.GSOMaxSize()
}
func (e *endpoint) initGSO() {
if e.route.HasHardwareGSOCapability() {
e.initHardwareGSO()
} else if e.route.HasSoftwareGSOCapability() {
- e.gso = &stack.GSO{
+ e.gso = stack.GSO{
MaxSize: e.route.GSOMaxSize(),
Type: stack.GSOSW,
NeedsCsum: false,
@@ -3200,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool {
e.lastOutOfWindowAckTime = now
return true
}
+
+// GetTCPReceiveBufferLimits is used to get receive buffer size limits for TCP.
+func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption {
+ var ss tcpip.TCPReceiveBufferSizeRangeOption
+ if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err))
+ }
+
+ return tcpip.ReceiveBufferSizeOption{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }
+}
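
GetTCPReceiveBufferLimits is the receive-side twin of GetTCPSendBufferLimits and is passed to ops.InitHandler in the Resume path below. A hedged sketch of a caller clamping a requested size to these limits (clampRcvBuf is hypothetical, assumed to sit in package tcp):

    import "gvisor.dev/gvisor/pkg/tcpip"

    // clampRcvBuf clamps a requested receive buffer size to the
    // stack-wide TCP limits.
    func clampRcvBuf(s tcpip.StackHandler, req int64) int64 {
        lim := GetTCPReceiveBufferLimits(s)
        switch {
        case req < int64(lim.Min):
            return int64(lim.Min)
        case req > int64(lim.Max):
            return int64(lim.Max)
        default:
            return req
        }
    }
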
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index a53d76917..6e9777fe4 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() {
if !e.route.HasSaveRestoreCapability() {
if !e.route.HasDisconncetOkCapability() {
panic(&tcpip.ErrSaveRejection{
- Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort),
+ Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
})
}
e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
@@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if !e.workerRunning {
- // The endpoint must be in acceptedChan or has been just
+ // The endpoint must be in the accepted queue or has just been
// disconnected and closed.
break
}
@@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if e.workerRunning {
- panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID))
}
default:
panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
@@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() {
}
}
-// saveAcceptedChan is invoked by stateify.
-func (e *endpoint) saveAcceptedChan() []*endpoint {
- if e.acceptedChan == nil {
- return nil
- }
- acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case ep := <-e.acceptedChan:
- acceptedEndpoints[i] = ep
- default:
- panic("endpoint acceptedChan buffer got consumed by background context")
- }
- }
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case e.acceptedChan <- acceptedEndpoints[i]:
- default:
- panic("endpoint acceptedChan buffer got populated by background context")
- }
+// saveEndpoints is invoked by stateify.
+func (a *accepted) saveEndpoints() []*endpoint {
+ acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
+ for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
+ acceptedEndpoints[i] = e.Value.(*endpoint)
}
return acceptedEndpoints
}
-// loadAcceptedChan is invoked by stateify.
-func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
- if cap(acceptedEndpoints) > 0 {
- e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
- for _, ep := range acceptedEndpoints {
- e.acceptedChan <- ep
- }
+// loadEndpoints is invoked by stateify.
+func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+ for _, ep := range acceptedEndpoints {
+ a.endpoints.PushBack(ep)
}
}
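
Switching the accepted queue from a channel to a list is what lets save/restore shrink from the drain-and-refill dance above to a plain walk. The same shape over Go's container/list, as a runnable sketch:

    package main

    import (
        "container/list"
        "fmt"
    )

    type queue struct{ l list.List }

    // save walks the list front to back; no elements are consumed, so
    // there is no refill step and no capacity bookkeeping.
    func (q *queue) save() []int {
        out := make([]int, 0, q.l.Len())
        for e := q.l.Front(); e != nil; e = e.Next() {
            out = append(out, e.Value.(int))
        }
        return out
    }

    // load rebuilds the list in order.
    func (q *queue) load(vals []int) {
        for _, v := range vals {
            q.l.PushBack(v)
        }
    }

    func main() {
        var q queue
        q.load([]int{1, 2, 3})
        fmt.Println(q.save()) // [1 2 3]
    }
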
@@ -183,7 +165,7 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
@@ -198,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) {
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
+ if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) {
+ panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
@@ -231,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.ID.RemoteAddress
+ e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
- if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning)
+ err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning)
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
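
The connected-restore path above re-dials using an IPv4-mapped IPv6 address when a dual-stack (IPv6) endpoint holds an IPv4 peer: the raw 4-byte address is prefixed with ::ffff:0:0/96. A small sketch of that construction:

    package main

    import (
        "fmt"
        "net"
    )

    // v4Mapped prefixes a 4-byte IPv4 address with the standard
    // IPv4-mapped IPv6 prefix ::ffff:0:0/96.
    func v4Mapped(v4 string) string {
        return "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + v4
    }

    func main() {
        addr := v4Mapped("\x7f\x00\x00\x01")
        fmt.Println(len(addr), net.IP(addr)) // 16 127.0.0.1 (mapped form)
    }
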
@@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := cap(e.acceptedChan)
+ backlog := e.accepted.cap
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
@@ -281,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort})
+ err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
@@ -328,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime {
func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) {
e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano)
}
-
-// saveMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
- return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
-}
-
-// loadMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
- r.measureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRttMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
- return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
-}
-
-// loadRttMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
- r.rttMeasureTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a4667906..a3d1aa1a3 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -75,63 +75,6 @@ const (
ccCubic = "cubic"
)
-// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
-// value is protected by a mutex so that we can increment only when it's
-// guaranteed not to go above a threshold.
-type synRcvdCounter struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
- threshold uint64
-}
-
-// inc tries to increment the global number of endpoints in SYN-RCVD state. It
-// succeeds if the increment doesn't make the count go beyond the threshold, and
-// fails otherwise.
-func (s *synRcvdCounter) inc() bool {
- s.Lock()
- defer s.Unlock()
- if s.value >= s.threshold {
- return false
- }
-
- s.pending.Add(1)
- s.value++
-
- return true
-}
-
-// dec atomically decrements the global number of endpoints in SYN-RCVD
-// state. It must only be called if a previous call to inc succeeded.
-func (s *synRcvdCounter) dec() {
- s.Lock()
- defer s.Unlock()
- s.value--
- s.pending.Done()
-}
-
-// synCookiesInUse returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func (s *synRcvdCounter) synCookiesInUse() bool {
- s.Lock()
- defer s.Unlock()
- return s.value >= s.threshold
-}
-
-// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
-func (s *synRcvdCounter) SetThreshold(threshold uint64) {
- s.Lock()
- defer s.Unlock()
- s.threshold = threshold
-}
-
-// Threshold returns the current value of synRcvdCounter.Threhsold.
-func (s *synRcvdCounter) Threshold() uint64 {
- s.Lock()
- defer s.Unlock()
- return s.threshold
-}
-
type protocol struct {
stack *stack.Stack
@@ -139,6 +82,7 @@ type protocol struct {
sackEnabled bool
recovery tcpip.TCPRecovery
delayEnabled bool
+ alwaysUseSynCookies bool
sendBufferSize tcpip.TCPSendBufferSizeRangeOption
recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
@@ -150,7 +94,6 @@ type protocol struct {
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
- synRcvdCount synRcvdCounter
synRetries uint8
dispatcher dispatcher
}
@@ -216,8 +159,8 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
-func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
- route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
+ route, err := st.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -257,7 +200,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error
seq: seq,
ack: ack,
rcvWnd: 0,
- }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */)
+ }, buffer.VectorisedView{}, stack.GSO{}, nil /* PacketOwner */)
}
// SetOption implements stack.TransportProtocol.SetOption.
@@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
p.mu.Unlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(*v))
+ p.alwaysUseSynCookies = bool(*v)
p.mu.Unlock()
return nil
@@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er
p.mu.RUnlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.RLock()
- *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies)
p.mu.RUnlock()
return nil
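
With the SYN-RCVD threshold gone, the remaining knob is the boolean TCPAlwaysUseSynCookies. A hedged usage sketch (stk is assumed to be an already-configured *stack.Stack with the TCP protocol registered):

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
        "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    )

    // forceSynCookies turns SYN cookies on for every passive open.
    func forceSynCookies(stk *stack.Stack) tcpip.Error {
        opt := tcpip.TCPAlwaysUseSynCookies(true)
        return stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
    }
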
@@ -507,12 +450,6 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
-// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
-// instance.
-func (p *protocol) SynRcvdCounter() *synRcvdCounter {
- return &p.synRcvdCount
-}
-
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
return parse.TCP(pkt)
@@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
lingerTimeout: DefaultTCPLingerTimeout,
timeWaitTimeout: DefaultTCPTimeWaitTimeout,
timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
- synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0a0d5f7a1..9e332dcf7 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -46,54 +47,16 @@ const (
//
// +stateify savable
type rackControl struct {
- // dsackSeen indicates if the connection has seen a DSACK.
- dsackSeen bool
-
- // endSequence is the ending TCP sequence number of the most recent
- // acknowledged segment.
- endSequence seqnum.Value
+ stack.TCPRACKState
// exitedRecovery indicates if the connection is exiting loss recovery.
// This flag is set if the sender is leaving the recovery after
// receiving an ACK and is reset during updating of reorder window.
exitedRecovery bool
- // fack is the highest selectively or cumulatively acknowledged
- // sequence.
- fack seqnum.Value
-
// minRTT is the estimated minimum RTT of the connection.
minRTT time.Duration
- // reorderSeen indicates if reordering has been detected on this
- // connection.
- reorderSeen bool
-
- // reoWnd is the reordering window time used for recording packet
- // transmission times. It is used to defer the moment at which RACK
- // marks a packet lost.
- reoWnd time.Duration
-
- // reoWndIncr is the multiplier applied to adjust reorder window.
- reoWndIncr uint8
-
- // reoWndPersist is the number of loss recoveries before resetting
- // reorder window.
- reoWndPersist int8
-
- // rtt is the RTT of the most recently delivered packet on the
- // connection (either cumulatively acknowledged or selectively
- // acknowledged) that was not marked invalid as a possible spurious
- // retransmission.
- rtt time.Duration
-
- // rttSeq is the SND.NXT when rtt is updated.
- rttSeq seqnum.Value
-
- // xmitTime is the latest transmission timestamp of the most recent
- // acknowledged segment.
- xmitTime time.Time `state:".(unixTime)"`
-
// tlpRxtOut indicates whether there is an unacknowledged
// TLP retransmission.
tlpRxtOut bool
@@ -108,8 +71,8 @@ type rackControl struct {
// init initializes RACK specific fields.
func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
- rc.fack = iss
- rc.reoWndIncr = 1
+ rc.FACK = iss
+ rc.ReoWndIncr = 1
rc.snd = snd
}
@@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := time.Now().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.tsOffset
+ tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
}
}
- rc.rtt = rtt
+ rc.RTT = rtt
// The sender can either track a simple global minimum of all RTT
// measurements from the connection, or a windowed min-filtered value
@@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// ending sequence number of the packet which has been acknowledged
// most recently.
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
- rc.xmitTime = seg.xmitTime
- rc.endSequence = endSeq
+ if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ rc.XmitTime = seg.xmitTime
+ rc.EndSequence = endSeq
}
}
@@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// is identified.
func (rc *rackControl) detectReorder(seg *segment) {
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.fack.LessThan(endSeq) {
- rc.fack = endSeq
+ if rc.FACK.LessThan(endSeq) {
+ rc.FACK = endSeq
return
}
- if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
- rc.reorderSeen = true
+ if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 {
+ rc.Reord = true
}
}
func (rc *rackControl) setDSACKSeen(dsackSeen bool) {
- rc.dsackSeen = dsackSeen
+ rc.DSACKSeen = dsackSeen
}
// shouldSchedulePTO dictates whether we should schedule a PTO or not.
@@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool {
// Schedule PTO only if RACK loss detection is enabled.
return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 &&
// The connection supports SACK.
- s.ep.sackPermitted &&
+ s.ep.SACKPermitted &&
// The connection is not in loss recovery.
(s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) &&
// The connection has no SACKed sequences in the SACK scoreboard.
@@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool {
func (s *sender) schedulePTO() {
pto := time.Second
s.rtt.Lock()
- if s.rtt.srttInited && s.rtt.srtt > 0 {
- pto = s.rtt.srtt * 2
- if s.outstanding == 1 {
+ if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 {
+ pto = s.rtt.TCPRTTState.SRTT * 2
+ if s.Outstanding == 1 {
pto += wcDelayedACKTimeout
}
}
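
The probe timeout above follows the RACK draft: 2*SRTT when a smoothed RTT exists, padded by the worst-case delayed-ACK timeout when exactly one segment is outstanding, else a 1s fallback. As a self-contained sketch (the 200ms value for wcDelayedACKTimeout is an assumption here):

    package main

    import (
        "fmt"
        "time"
    )

    const wcDelayedACKTimeout = 200 * time.Millisecond // assumed value

    func pto(srtt time.Duration, srttInited bool, outstanding int) time.Duration {
        p := time.Second // fallback when no RTT sample exists
        if srttInited && srtt > 0 {
            p = 2 * srtt
            if outstanding == 1 {
                // Leave room for the peer's delayed ACK.
                p += wcDelayedACKTimeout
            }
        }
        return p
    }

    func main() {
        fmt.Println(pto(50*time.Millisecond, true, 1)) // 300ms
    }
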
@@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error {
}
var dataSent bool
- if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd {
- dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd))
+ if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd {
+ dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd))
if dataSent {
- s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize)
+ s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize)
s.writeNext = s.writeNext.Next()
}
}
@@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error {
}
if highestSeqXmit != nil {
- dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd))
+ dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd))
if dataSent {
s.rc.tlpRxtOut = true
- s.rc.tlpHighRxt = s.sndNxt
+ s.rc.tlpHighRxt = s.SndNxt
}
}
}
@@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error {
// and updates TLP state accordingly.
// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3.
func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
- if !(s.ep.sackPermitted && s.rc.tlpRxtOut) {
+ if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) {
return
}
@@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
// retransmit quickly, or when the number of DUPACKs exceeds the classic
// DUPACKthreshold.
func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
- dsackSeen := rc.dsackSeen
+ dsackSeen := rc.DSACKSeen
snd := rc.snd
// React to DSACK once per round trip.
// If SND.UNA < RACK.rtt_seq:
// RACK.dsack = false
- if snd.sndUna.LessThan(rc.rttSeq) {
+ if snd.SndUna.LessThan(rc.RTTSeq) {
dsackSeen = false
}
@@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// RACK.rtt_seq = SND.NXT
// RACK.reo_wnd_persist = 16
if dsackSeen {
- rc.reoWndIncr++
+ rc.ReoWndIncr++
dsackSeen = false
- rc.rttSeq = snd.sndNxt
- rc.reoWndPersist = tcpRACKRecoveryThreshold
+ rc.RTTSeq = snd.SndNxt
+ rc.ReoWndPersist = tcpRACKRecoveryThreshold
} else if rc.exitedRecovery {
// Else if exiting loss recovery:
// RACK.reo_wnd_persist -= 1
// If RACK.reo_wnd_persist <= 0:
// RACK.reo_wnd_incr = 1
- rc.reoWndPersist--
- if rc.reoWndPersist <= 0 {
- rc.reoWndIncr = 1
+ rc.ReoWndPersist--
+ if rc.ReoWndPersist <= 0 {
+ rc.ReoWndIncr = 1
}
rc.exitedRecovery = false
}
@@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// Else if RACK.pkts_sacked >= RACK.dupthresh:
// RACK.reo_wnd = 0
// return
- if !rc.reorderSeen {
+ if !rc.Reord {
if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery {
- rc.reoWnd = 0
+ rc.ReoWnd = 0
return
}
- if snd.sackedOut >= nDupAckThreshold {
- rc.reoWnd = 0
+ if snd.SackedOut >= nDupAckThreshold {
+ rc.ReoWnd = 0
return
}
}
@@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr
// RACK.reo_wnd = min(RACK.reo_wnd, SRTT)
snd.rtt.Lock()
- srtt := snd.rtt.srtt
+ srtt := snd.rtt.TCPRTTState.SRTT
snd.rtt.Unlock()
- rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr))
- if srtt < rc.reoWnd {
- rc.reoWnd = srtt
+ rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr))
+ if srtt < rc.ReoWnd {
+ rc.ReoWnd = srtt
}
}
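
The last step above is the draft's formula RACK.reo_wnd = min(RACK.min_RTT/4 * RACK.reo_wnd_incr, SRTT). Worked as a runnable sketch:

    package main

    import (
        "fmt"
        "time"
    )

    // reoWnd computes the RACK reordering window from the connection's
    // minimum RTT, the DSACK-driven multiplier, and the smoothed-RTT cap.
    func reoWnd(minRTT, srtt time.Duration, incr uint8) time.Duration {
        w := time.Duration((int64(minRTT) / 4) * int64(incr))
        if srtt < w {
            w = srtt
        }
        return w
    }

    func main() {
        fmt.Println(reoWnd(40*time.Millisecond, 60*time.Millisecond, 2)) // 20ms
    }
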
@@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int {
}
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
- timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd
+ if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd
if timeRemaining <= 0 {
seg.lost = true
numLost++
@@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error {
}
fastRetransmit := false
- if !rc.snd.fr.active {
+ if !rc.snd.FastRecovery.Active {
rc.snd.cc.HandleLossDetected()
rc.snd.enterRecovery()
fastRetransmit = true
@@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) {
}
// Check the congestion window after entering recovery.
- if snd.outstanding >= snd.sndCwnd {
+ if snd.Outstanding >= snd.SndCwnd {
break
}
- if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent {
+ if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent {
break
}
dataSent = true
- snd.outstanding += snd.pCount(seg, snd.maxPayloadSize)
+ snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize)
}
snd.postXmit(dataSent, true /* shouldScheduleProbe */)
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index bc6793fc6..ee2c08cd6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// receiver holds the state necessary to receive TCP segments and turn them
@@ -29,26 +30,15 @@ import (
//
// +stateify savable
type receiver struct {
+ stack.TCPReceiverState
ep *endpoint
- rcvNxt seqnum.Value
-
- // rcvAcc is one beyond the last acceptable sequence number. That is,
- // the "largest" sequence value that the receiver has announced to the
- // its peer that it's willing to accept. This may be different than
- // rcvNxt + rcvWnd if the receive window is reduced; in that case we
- // have to reduce the window as we receive more data instead of
- // shrinking it.
- rcvAcc seqnum.Value
-
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
- // rcvWUP is the rcvNxt value at the last window update sent.
+ // rcvWUP is the RcvNxt value at the last window update sent.
rcvWUP seqnum.Value
- rcvWndScale uint8
-
// prevBufUsed is the snapshot of endpoint rcvBufUsed taken when we
// advertise a receive window.
prevBufUsed int
@@ -58,9 +48,6 @@ type receiver struct {
// pendingRcvdSegments is bounded by the receive buffer size of the
// endpoint.
pendingRcvdSegments segmentHeap
- // pendingBufUsed tracks the total number of bytes (including segment
- // overhead) currently queued in pendingRcvdSegments.
- pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
@@ -68,12 +55,14 @@ type receiver struct {
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
+ ep: ep,
+ TCPReceiverState: stack.TCPReceiverState{
+ RcvNxt: irs + 1,
+ RcvAcc: irs.Add(rcvWnd + 1),
+ RcvWndScale: rcvWndScale,
+ },
rcvWnd: rcvWnd,
rcvWUP: irs + 1,
- rcvWndScale: rcvWndScale,
lastRcvdAckTime: time.Now(),
}
}
@@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// r.rcvWnd could be much larger than the window size we advertised in our
// outgoing packets, so we should use what we have advertised for the
// acceptability test.
- scaledWindowSize := r.rcvWnd >> r.rcvWndScale
+ scaledWindowSize := r.rcvWnd >> r.RcvWndScale
if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
scaledWindowSize = math.MaxUint16
}
- advertisedWindowSize := scaledWindowSize << r.rcvWndScale
- return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
+ advertisedWindowSize := scaledWindowSize << r.RcvWndScale
+ return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize))
}
// currentWindow returns the available space in the window that was advertised
// last to our peer.
func (r *receiver) currentWindow() (curWnd seqnum.Size) {
endOfWnd := r.rcvWUP.Add(r.rcvWnd)
- if endOfWnd.LessThan(r.rcvNxt) {
- // return 0 if r.rcvNxt is past the end of the previously advertised window.
+ if endOfWnd.LessThan(r.RcvNxt) {
+ // return 0 if r.RcvNxt is past the end of the previously advertised window.
// This can happen because we accept a large segment completely even if
// accepting it causes it to partially exceed the advertised window.
return 0
}
- return r.rcvNxt.Size(endOfWnd)
+ return r.RcvNxt.Size(endOfWnd)
}
// getSendParams returns the parameters needed by the sender when building
// segments to send.
-func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
- unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt))
bufUsed := r.ep.receiveBufferUsed()
// Grow the right edge of the window only for payloads larger than the
@@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// edge, as we are still advertising a window that we think can be serviced.
toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
- // Update rcvAcc only if new window is > previously advertised window. We
+ // Update RcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised to the peer. If we shrink the acceptable sequence space then we
// would end up dropping bytes that might already be in flight.
// ==================================================== sequence space.
// ^ ^ ^ ^
- // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // rcvWUP RcvNxt RcvAcc new RcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
- // If the new window moves the right edge, then update rcvAcc.
- r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
+ if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
+ // If the new window moves the right edge, then update RcvAcc.
+ r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd))
} else {
if newWnd == 0 {
// newWnd is zero but we can't advertise a zero as it would cause window
@@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
- r.rcvWUP = r.rcvNxt
+ r.rcvWUP = r.RcvNxt
r.prevBufUsed = bufUsed
- scaledWnd := r.rcvWnd >> r.rcvWndScale
+ scaledWnd := r.rcvWnd >> r.RcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
@@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// Ensure that the stashed receive window always reflects what
// is being advertised.
- r.rcvWnd = scaledWnd << r.rcvWndScale
+ r.rcvWnd = scaledWnd << r.RcvWndScale
}
- return r.rcvNxt, scaledWnd
+ return r.RcvNxt, scaledWnd
}
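
The diagram above encodes one invariant: RcvAcc, the promised right edge, may only move right. On plain uint32 sequence arithmetic (the real code uses seqnum.Value), the rule looks like:

    package main

    import "fmt"

    // growRcvAcc returns the new right edge. It moves only when the new
    // window reaches past the edge already advertised to the peer;
    // unsigned subtraction keeps the comparison wrap-safe.
    func growRcvAcc(rcvNxt, rcvAcc, newWnd uint32) uint32 {
        if newEdge := rcvNxt + newWnd; newEdge-rcvNxt > rcvAcc-rcvNxt {
            return newEdge
        }
        return rcvAcc
    }

    func main() {
        fmt.Println(growRcvAcc(1000, 1500, 800)) // 1800: edge grows
        fmt.Println(growRcvAcc(1000, 1500, 200)) // 1500: never shrinks
    }
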
// nonZeroWindow is called when the receive window grows from zero to nonzero;
@@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// If the segment doesn't include the seqnum we're expecting to
// consume now, we're missing a segment. We cannot proceed until
// we receive that segment though.
- if !r.rcvNxt.InWindow(segSeq, segLen) {
+ if !r.RcvNxt.InWindow(segSeq, segLen) {
return false
}
// Trim segment to eliminate already acknowledged data.
- if segSeq.LessThan(r.rcvNxt) {
- diff := segSeq.Size(r.rcvNxt)
+ if segSeq.LessThan(r.RcvNxt) {
+ diff := segSeq.Size(r.RcvNxt)
segLen -= diff
segSeq.UpdateForward(diff)
s.sequenceNumber.UpdateForward(diff)
@@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Move segment to ready-to-deliver list. Wakeup any waiters.
r.ep.readyToRead(s)
- } else if segSeq != r.rcvNxt {
+ } else if segSeq != r.RcvNxt {
return false
}
// Update the segment that we're expecting to consume.
- r.rcvNxt = segSeq.Add(segLen)
+ r.RcvNxt = segSeq.Add(segLen)
// In cases of a misbehaving sender which could send more than the
// advertised window, we could end up in a situation where we get a
// segment that exceeds the window advertised. Instead of partially
// accepting the segment and discarding bytes beyond the advertised
- // window, we accept the whole segment and make sure r.rcvAcc is moved
- // forward to match r.rcvNxt to indicate that the window is now closed.
+ // window, we accept the whole segment and make sure r.RcvAcc is moved
+ // forward to match r.RcvNxt to indicate that the window is now closed.
//
// In absence of this check the r.acceptable() check fails and accepts
// segments that should be dropped because rcvWnd is calculated as
- // the size of the interval (rcvNxt, rcvAcc] which becomes extremely
- // large if rcvAcc is ever less than rcvNxt.
- if r.rcvAcc.LessThan(r.rcvNxt) {
- r.rcvAcc = r.rcvNxt
+ // the size of the interval (RcvNxt, RcvAcc] which becomes extremely
+ // large if RcvAcc is ever less than RcvNxt.
+ if r.RcvAcc.LessThan(r.RcvNxt) {
+ r.RcvAcc = r.RcvNxt
}
// Trim SACK Blocks to remove any SACK information that covers
// sequence numbers that have been consumed.
- TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+ TrimSACKBlockList(&r.ep.sack, r.RcvNxt)
// Handle FIN or FIN-ACK.
if s.flagIsSet(header.TCPFlagFin) {
- r.rcvNxt++
+ r.RcvNxt++
// Send ACK immediately.
r.ep.snd.sendAck()
@@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
@@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
- r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
+ r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
// Note that slice truncation does not allow garbage collection of
@@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -323,40 +312,40 @@ func (r *receiver) updateRTT() {
// estimate the round-trip time by observing the time between when a byte
// is first acknowledged and the receipt of data that is at least one
// window beyond the sequence number that was acknowledged.
- r.ep.rcvListMu.Lock()
- if r.ep.rcvAutoParams.rttMeasureTime.IsZero() {
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() {
// New measurement.
- r.ep.rcvAutoParams.rttMeasureTime = time.Now()
- r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) {
- r.ep.rcvListMu.Unlock()
+ if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) {
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime)
+ rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
// We only store the minimum observed RTT here as this is only used in
// absence of a SRTT available from either timestamps or a sender
// measurement of RTT.
- if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt {
- r.ep.rcvAutoParams.rtt = rtt
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT {
+ r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt
}
- r.ep.rcvAutoParams.rttMeasureTime = time.Now()
- r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
}
func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) {
- r.ep.rcvListMu.Lock()
- rcvClosed := r.ep.rcvClosed || r.closed
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
// If we are in one of the shutdown states then we need to do
// additional checks before we try and process the segment.
switch state {
case StateCloseWait, StateClosing, StateLastAck:
- if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ if !s.sequenceNumber.LessThanEq(r.RcvNxt) {
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// The ESTABLISHED state processing is here where if the ACK check
// fails, we ignore the packet:
// https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591
- if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ if r.ep.snd.SndNxt.LessThan(s.ackNumber) {
r.ep.snd.maybeSendOutOfWindowAck(s)
return true, nil
}
// If we are closed for reads (either due to an
// incoming FIN or the user calling shutdown(..,
- // SHUT_RD) then any data past the rcvNxt should
+ // SHUT_RD) then any data past the RcvNxt should
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
- if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) {
return true, &tcpip.ErrConnectionAborted{}
}
if state == StateFinWait1 {
@@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// If it's a retransmission of an old data segment
// or a pure ACK then allow it.
- if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) ||
s.logicalLen() == 0 {
break
}
@@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// then the only acceptable segment is a
// FIN. Since FIN can technically also carry
// data we verify that the segment carrying a
- // FIN ends at exactly e.rcvNxt+1.
+ // FIN ends at exactly e.RcvNxt+1.
//
// From RFC793 page 25.
//
@@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// while the FIN is considered to occur after
// the last actual data octet in a segment in
// which it occurs.
- if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// end has closed and the peer is yet to send a FIN. Hence we
// compare only the payload.
segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
- if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) {
return true, nil
}
return false, nil
@@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
- r.ep.rcvListMu.Lock()
- r.pendingBufUsed += s.segMemSize()
- r.ep.rcvListMu.Unlock()
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ r.PendingBufUsed += s.segMemSize()
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
- UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt)
}
// Immediately send an ack so that the peer knows it may
@@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
segSeq := s.sequenceNumber
// Skip segment altogether if it has already been acknowledged.
- if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) &&
!r.consumeSegment(s, segSeq, segLen) {
break
}
heap.Pop(&r.pendingRcvdSegments)
- r.ep.rcvListMu.Lock()
- r.pendingBufUsed -= s.segMemSize()
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ r.PendingBufUsed -= s.segMemSize()
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
s.decRef()
}
return false, nil
@@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
- if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+ if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
return false, true
}
@@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if r.ep.sendTSOk && s.parsedOptions.TS {
- r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ if r.ep.SendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq)
}
- if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) {
// If it's a FIN-ACK then resetTimeWait and send an ACK, as it
// indicates our final ACK could have been lost.
r.ep.snd.sendAck()
@@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// carries data then just send an ACK. This is according to RFC 793,
// page 37.
//
- // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
- if segSeq != r.rcvNxt || segLen != 0 {
+ // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt.
+ if segSeq != r.RcvNxt || segLen != 0 {
r.ep.snd.sendAck()
}
return false, false
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
index ff39780a5..063552c7f 100644
--- a/pkg/tcpip/transport/tcp/reno.go
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState {
func (r *renoState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
- newcwnd := r.s.sndCwnd + packetsAcked
- if newcwnd >= r.s.sndSsthresh {
- newcwnd = r.s.sndSsthresh
- r.s.sndCAAckCount = 0
+ newcwnd := r.s.SndCwnd + packetsAcked
+ if newcwnd >= r.s.Ssthresh {
+ newcwnd = r.s.Ssthresh
+ r.s.SndCAAckCount = 0
}
- packetsAcked -= newcwnd - r.s.sndCwnd
- r.s.sndCwnd = newcwnd
+ packetsAcked -= newcwnd - r.s.SndCwnd
+ r.s.SndCwnd = newcwnd
return packetsAcked
}
@@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int {
// avoidance mode as described in RFC5681 section 3.1
func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
// Consume the packets in congestion avoidance mode.
- r.s.sndCAAckCount += packetsAcked
- if r.s.sndCAAckCount >= r.s.sndCwnd {
- r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
- r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ r.s.SndCAAckCount += packetsAcked
+ if r.s.SndCAAckCount >= r.s.SndCwnd {
+ r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd
+ r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd
}
}
// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
// page 6, eq. 4. It is called when we detect congestion in the network.
func (r *renoState) reduceSlowStartThreshold() {
- r.s.sndSsthresh = r.s.outstanding / 2
- if r.s.sndSsthresh < 2 {
- r.s.sndSsthresh = 2
+ r.s.Ssthresh = r.s.Outstanding / 2
+ if r.s.Ssthresh < 2 {
+ r.s.Ssthresh = 2
}
}
@@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() {
// were acknowledged.
// Update implements congestionControl.Update.
func (r *renoState) Update(packetsAcked int) {
- if r.s.sndCwnd < r.s.sndSsthresh {
+ if r.s.SndCwnd < r.s.Ssthresh {
packetsAcked = r.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
@@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() {
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
- r.s.sndCwnd = 1
+ r.s.SndCwnd = 1
}
// PostRecovery implements congestionControl.PostRecovery.
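
Apart from the renames, the Reno arithmetic in this file is unchanged. Pulled together as a self-contained sketch (units are packets, as above):

    package main

    import "fmt"

    type reno struct {
        cwnd, ssthresh, caAckCount int
    }

    // onAck grows cwnd: exponentially in slow start (capped at
    // ssthresh), then by roughly one packet per RTT in congestion
    // avoidance.
    func (r *reno) onAck(packetsAcked int) {
        if r.cwnd < r.ssthresh {
            newcwnd := r.cwnd + packetsAcked
            if newcwnd >= r.ssthresh {
                newcwnd = r.ssthresh
                r.caAckCount = 0
            }
            packetsAcked -= newcwnd - r.cwnd
            r.cwnd = newcwnd
            if packetsAcked == 0 {
                return
            }
        }
        r.caAckCount += packetsAcked
        if r.caAckCount >= r.cwnd {
            r.cwnd += r.caAckCount / r.cwnd
            r.caAckCount %= r.cwnd
        }
    }

    // onLoss applies RFC 5681, page 6, eq. 4: halve, floor at 2.
    func (r *reno) onLoss(outstanding int) {
        r.ssthresh = outstanding / 2
        if r.ssthresh < 2 {
            r.ssthresh = 2
        }
    }

    func main() {
        r := &reno{cwnd: 10, ssthresh: 12}
        r.onAck(4)                        // crosses into avoidance
        fmt.Println(r.cwnd, r.caAckCount) // 12 2
    }
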
diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go
index 2aa708e97..d368a29fc 100644
--- a/pkg/tcpip/transport/tcp/reno_recovery.go
+++ b/pkg/tcpip/transport/tcp/reno_recovery.go
@@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
snd := rr.s
// We are in fast recovery mode. Ignore the ack if it's out of range.
- if !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ if !ack.InRange(snd.SndUna, snd.SndNxt+1) {
return
}
// Don't count this as a duplicate if it is carrying data or
// updating the window.
- if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window {
+ if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window {
return
}
// Inflate the congestion window if we're getting duplicate acks
// for the packet we retransmitted.
- if !fastRetransmit && ack == snd.fr.first {
+ if !fastRetransmit && ack == snd.FastRecovery.First {
// We received a dup, inflate the congestion window by 1 packet
// if we're not at the max yet. Only inflate the window if
// regular FastRecovery is in use, RFC6675 does not require
// inflating cwnd on duplicate ACKs.
- if snd.sndCwnd < snd.fr.maxCwnd {
- snd.sndCwnd++
+ if snd.SndCwnd < snd.FastRecovery.MaxCwnd {
+ snd.SndCwnd++
}
return
}
@@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
// back onto the wire.
//
// N.B. The retransmit timer will be reset by the caller.
- snd.fr.first = ack
- snd.dupAckCount = 0
+ snd.FastRecovery.First = ack
+ snd.DupAckCount = 0
snd.resendSegment()
}
diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go
index 9d406b0bc..cd860b5e8 100644
--- a/pkg/tcpip/transport/tcp/sack_recovery.go
+++ b/pkg/tcpip/transport/tcp/sack_recovery.go
@@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
}
nextSegHint := snd.writeList.Front()
- for snd.outstanding < snd.sndCwnd {
+ for snd.Outstanding < snd.SndCwnd {
var nextSeg *segment
var rescueRtx bool
nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint)
if nextSeg == nil {
return dataSent
}
- if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) {
// New data being sent.
// Step C.3 described below is handled by
@@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
return dataSent
}
dataSent = true
- snd.outstanding++
+ snd.Outstanding++
snd.writeNext = nextSeg.Next()
continue
}
@@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// "The estimate of the amount of data outstanding in the network
// must be updated by incrementing pipe by the number of octets
// transmitted in (C.1)."
- snd.outstanding++
+ snd.Outstanding++
dataSent = true
snd.sendSegment(nextSeg)
@@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// We do the last part of rule (4) of NextSeg here to update
// RescueRxt as until this point we don't know if we are going
// to use the rescue transmission.
- snd.fr.rescueRxt = snd.fr.last
+ snd.FastRecovery.RescueRxt = snd.FastRecovery.Last
} else {
// RFC 6675, Step C.2
//
@@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// HighData, HighRxt MUST be set to the highest sequence
// number of the retransmitted segment unless NextSeg ()
// rule (4) was invoked for this retransmission."
- snd.fr.highRxt = segEnd - 1
+ snd.FastRecovery.HighRxt = segEnd - 1
}
}
return dataSent
@@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
}
// We are in fast recovery mode. Ignore the ack if it's out of range.
- if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) {
return
}
// RFC 6675 recovery algorithm step C 1-5.
- end := snd.sndUna.Add(snd.sndWnd)
- dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end)
+ end := snd.SndUna.Add(snd.SndWnd)
+ dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end)
snd.postXmit(dataSent, true /* shouldScheduleProbe */)
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 8edd6775b..c28641be3 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
-
- verifyChecksum := true
if skipChecksumValidation {
s.csumValid = true
- verifyChecksum = false
- }
- if verifyChecksum {
+ } else {
s.csum = s.hdr.Checksum()
- xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
- xsum = s.hdr.CalculateChecksum(xsum)
- xsum = header.ChecksumVV(s.data, xsum)
- s.csumValid = xsum == 0xffff
+ payloadChecksum := header.ChecksumVV(s.data, 0)
+ payloadLength := uint16(s.data.Size())
+ s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, s.dstAddr, payloadChecksum, payloadLength)
}
-
s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
s.ackNumber = seqnum.Value(s.hdr.AckNumber())
s.flags = s.hdr.Flags()
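
parse now folds the payload checksum once with ChecksumVV and lets header.TCP.IsChecksumValid add the pseudo-header, instead of chaining the three sums inline. Both rest on the 16-bit one's-complement sum; a standalone sketch:

    package main

    import "fmt"

    // checksum is the RFC 1071 one's-complement sum over b, seeded with
    // initial; a received segment verifies when the total folds to 0xffff.
    func checksum(b []byte, initial uint32) uint16 {
        sum := initial
        for len(b) >= 2 {
            sum += uint32(b[0])<<8 | uint32(b[1])
            b = b[2:]
        }
        if len(b) == 1 {
            sum += uint32(b[0]) << 8 // pad the trailing odd byte
        }
        for sum > 0xffff {
            sum = (sum >> 16) + (sum & 0xffff) // fold the carries
        }
        return uint16(sum)
    }

    func main() {
        fmt.Printf("%#x\n", checksum([]byte{0x00, 0x01, 0xff, 0xfe}, 0)) // 0xffff
    }
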
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 54545a1b1..d0d1b0b8a 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool {
func (q *segmentQueue) enqueue(s *segment) bool {
// q.ep.receiveBufferParams() must be called without holding q.mu to
// avoid lock order inversion.
- bufSz := q.ep.receiveBufferSize()
+ bufSz := q.ep.ops.GetReceiveBufferSize()
used := q.ep.receiveMemUsed()
q.mu.Lock()
// Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
// is currently full).
- allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+ allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen
if allow {
q.list.PushBack(s)
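
The admission rule itself is unchanged: data segments are bounded by the receive buffer (now read through ops.GetReceiveBufferSize), while zero-payload segments are always admitted so ACK/FIN/RST control packets survive memory pressure. Reduced to a runnable predicate:

    package main

    import "fmt"

    // admit reports whether a segment may be queued: data is bounded by
    // the buffer size, control segments (payload == 0) always pass, and
    // nothing is queued while the queue is frozen.
    func admit(used, bufSz, payload int, frozen bool) bool {
        return (used <= bufSz || payload == 0) && !frozen
    }

    func main() {
        fmt.Println(admit(5000, 4096, 1460, false)) // false: buffer full
        fmt.Println(admit(5000, 4096, 0, false))    // true: bare ACK/FIN/RST
    }
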
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index faca35892..2b32cb7b2 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -85,56 +86,12 @@ type lossRecovery interface {
//
// +stateify savable
type sender struct {
+ stack.TCPSenderState
ep *endpoint
- // lastSendTime is the timestamp when the last packet was sent.
- lastSendTime time.Time `state:".(unixTime)"`
-
- // dupAckCount is the number of duplicated acks received. It is used for
- // fast retransmit.
- dupAckCount int
-
- // fr holds state related to fast recovery.
- fr fastRecovery
-
// lr is the loss recovery algorithm used by the sender.
lr lossRecovery
- // sndCwnd is the congestion window, in packets.
- sndCwnd int
-
- // sndSsthresh is the threshold between slow start and congestion
- // avoidance.
- sndSsthresh int
-
- // sndCAAckCount is the number of packets acknowledged during congestion
- // avoidance. When enough packets have been ack'd (typically cwnd
- // packets), the congestion window is incremented by one.
- sndCAAckCount int
-
- // outstanding is the number of outstanding packets, that is, packets
- // that have been sent but not yet acknowledged.
- outstanding int
-
- // sackedOut is the number of packets which are selectively acked.
- sackedOut int
-
- // sndWnd is the send window size.
- sndWnd seqnum.Size
-
- // sndUna is the next unacknowledged sequence number.
- sndUna seqnum.Value
-
- // sndNxt is the sequence number of the next segment to be sent.
- sndNxt seqnum.Value
-
- // rttMeasureSeqNum is the sequence number being used for the latest RTT
- // measurement.
- rttMeasureSeqNum seqnum.Value
-
- // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
- rttMeasureTime time.Time `state:".(unixTime)"`
-
// firstRetransmittedSegXmitTime is the original transmit time of
// the first segment that was retransmitted due to RTO expiration.
firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
@@ -147,17 +104,15 @@ type sender struct {
// window probes.
unackZeroWindowProbes uint32 `state:"nosave"`
- closed bool
writeNext *segment
writeList segmentList
resendTimer timer `state:"nosave"`
resendWaker sleep.Waker `state:"nosave"`
- // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
- // "round-trip time variation" and "retransmit timeout", as defined in
+ // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed
+ // round-trip time", and "round-trip time variation", as defined in
// section 2 of RFC 6298.
rtt rtt
- rto time.Duration
// minRTO is the minimum permitted value for sender.rto.
minRTO time.Duration
@@ -168,20 +123,9 @@ type sender struct {
// maxRetries is the maximum permitted retransmissions.
maxRetries uint32
- // maxPayloadSize is the maximum size of the payload of a given segment.
- // It is initialized on demand.
- maxPayloadSize int
-
// gso is set if generic segmentation offload is enabled.
gso bool
- // sndWndScale is the number of bits to shift left when reading the send
- // window size from a segment.
- sndWndScale uint8
-
- // maxSentAck is the maxium acknowledgement actually sent.
- maxSentAck seqnum.Value
-
// state is the current state of congestion control for this endpoint.
state tcpip.CongestionControlState
@@ -209,41 +153,7 @@ type sender struct {
type rtt struct {
sync.Mutex `state:"nosave"`
- srtt time.Duration
- rttvar time.Duration
- srttInited bool
-}
-
-// fastRecovery holds information related to fast recovery from a packet loss.
-//
-// +stateify savable
-type fastRecovery struct {
- // active whether the endpoint is in fast recovery. The following fields
- // are only meaningful when active is true.
- active bool
-
- // first and last represent the inclusive sequence number range being
- // recovered.
- first seqnum.Value
- last seqnum.Value
-
- // maxCwnd is the maximum value the congestion window may be inflated to
- // due to duplicate acks. This exists to avoid attacks where the
- // receiver intentionally sends duplicate acks to artificially inflate
- // the sender's cwnd.
- maxCwnd int
-
- // highRxt is the highest sequence number which has been retransmitted
- // during the current loss recovery phase.
- // See: RFC 6675 Section 2 for details.
- highRxt seqnum.Value
-
- // rescueRxt is the highest sequence number which has been
- // optimistically retransmitted to prevent stalling of the ACK clock
- // when there is loss at the end of the window and no new data is
- // available for transmission.
- // See: RFC 6675 Section 2 for details.
- rescueRxt seqnum.Value
+ stack.TCPRTTState
}
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
@@ -253,22 +163,24 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
maxPayloadSize := int(mss) - ep.maxOptionSize()
s := &sender{
- ep: ep,
- sndWnd: sndWnd,
- sndUna: iss + 1,
- sndNxt: iss + 1,
- rto: 1 * time.Second,
- rttMeasureSeqNum: iss + 1,
- lastSendTime: time.Now(),
- maxPayloadSize: maxPayloadSize,
- maxSentAck: irs + 1,
- fr: fastRecovery{
- // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
- last: iss,
- highRxt: iss,
- rescueRxt: iss,
+ ep: ep,
+ TCPSenderState: stack.TCPSenderState{
+ SndWnd: sndWnd,
+ SndUna: iss + 1,
+ SndNxt: iss + 1,
+ RTTMeasureSeqNum: iss + 1,
+ LastSendTime: time.Now(),
+ MaxPayloadSize: maxPayloadSize,
+ MaxSentAck: irs + 1,
+ FastRecovery: stack.TCPFastRecoveryState{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ Last: iss,
+ HighRxt: iss,
+ RescueRxt: iss,
+ },
+ RTO: 1 * time.Second,
},
- gso: ep.gso != nil,
+ gso: ep.gso.Type != stack.GSONone,
}
if s.gso {
@@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
- s.sndWndScale = uint8(sndWndScale)
+ s.SndWndScale = uint8(sndWndScale)
}
s.resendTimer.init(&s.resendWaker)
@@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// Initialize SACK Scoreboard after updating max payload size as we use
// the maxPayloadSize as the smss when determining if a segment is lost
// etc.
- s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss)
// Get Stack wide config.
var minRTO tcpip.TCPMinRTOOption
@@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
// their initial values.
func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
- s.sndCwnd = InitialCwnd
+ s.SndCwnd = InitialCwnd
// Set sndSsthresh to the maximum int value, which depends on the
// platform.
- s.sndSsthresh = int(^uint(0) >> 1)
+ s.Ssthresh = int(^uint(0) >> 1)
switch congestionControlName {
case ccCubic:
@@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon
// initLossRecovery initiates the loss recovery algorithm for the sender.
func (s *sender) initLossRecovery() lossRecovery {
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
return newSACKRecovery(s)
}
return newRenoRecovery(s)
@@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m -= s.ep.maxOptionSize()
// We don't adjust up for now.
- if m >= s.maxPayloadSize {
+ if m >= s.MaxPayloadSize {
return
}
@@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
- oldMSS := s.maxPayloadSize
- s.maxPayloadSize = m
+ oldMSS := s.MaxPayloadSize
+ s.MaxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
}
@@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// maxPayloadSize.
s.ep.scoreboard.smss = uint16(m)
- s.outstanding -= count
- if s.outstanding < 0 {
- s.outstanding = 0
+ s.Outstanding -= count
+ if s.Outstanding < 0 {
+ s.Outstanding = 0
}
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
@@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
nextSeg = seg
}
- if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
// Update sackedOut for new maximum payload size.
- s.sackedOut -= s.pCount(seg, oldMSS)
- s.sackedOut += s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut -= s.pCount(seg, oldMSS)
+ s.SackedOut += s.pCount(seg, s.MaxPayloadSize)
}
}
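As an aside, the clamping step above distills to a few lines. A minimal sketch follows; the derivation of m from the MTU is abridged, and optionSize stands in for the endpoint's real header/option accounting (illustrative names, not part of this change):

	package tcpsketch

	// clampPayloadSize sketches the first step of updateMaxPayloadSize:
	// derive the new payload ceiling from a (possibly reduced) MTU, never
	// adjust upward, and never go below one byte.
	func clampPayloadSize(mtu, optionSize, current int) int {
		m := mtu - optionSize
		if m >= current {
			return current // We don't adjust up.
		}
		if m < 1 {
			m = 1
		}
		return m
	}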
@@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// sendAck sends an ACK segment.
func (s *sender) sendAck() {
- s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt)
}
// updateRTO updates the retransmit timeout when a new round-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
- if !s.rtt.srttInited {
- s.rtt.rttvar = rtt / 2
- s.rtt.srtt = rtt
- s.rtt.srttInited = true
+ if !s.rtt.TCPRTTState.SRTTInited {
+ s.rtt.TCPRTTState.RTTVar = rtt / 2
+ s.rtt.TCPRTTState.SRTT = rtt
+ s.rtt.TCPRTTState.SRTTInited = true
} else {
- diff := s.rtt.srtt - rtt
+ diff := s.rtt.TCPRTTState.SRTT - rtt
if diff < 0 {
diff = -diff
}
- // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when
// no timestamps are available.
- if !s.ep.sendTSOk {
- s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
- s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ if !s.ep.SendTSOk {
+ s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4
+ s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8
} else {
// When we are taking RTT measurements for every ACK, we
// need to use a modified method as specified in
// https://tools.ietf.org/html/rfc7323#appendix-G
- if s.outstanding == 0 {
+ if s.Outstanding == 0 {
s.rtt.Unlock()
return
}
@@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) {
// terms of packets and not bytes. This is similar to
// how Linux handles cwnd and inflight. In practice
// this approximation works as expected.
- expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+ expectedSamples := math.Ceil(float64(s.Outstanding) / 2)
// alpha & beta values are the original values as recommended in
// https://tools.ietf.org/html/rfc6298#section-2.3.
@@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) {
alphaPrime := alpha / expectedSamples
betaPrime := beta / expectedSamples
- rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
- srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
- s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
- s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second))
}
}
- s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar
s.rtt.Unlock()
- if s.rto < s.minRTO {
- s.rto = s.minRTO
+ if s.RTO < s.minRTO {
+ s.RTO = s.minRTO
}
}
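For reference, a minimal standalone sketch of the two update paths above; it omits the first-sample initialization and the outstanding == 0 early return, and the names are illustrative rather than part of this change:

	package tcpsketch

	import (
		"math"
		"time"
	)

	// rttUpdate applies either the plain RFC 6298 update (alpha = 1/8,
	// beta = 1/4) or, when every ACK is timestamped and thus a sample, the
	// RFC 7323 appendix G variant that scales alpha and beta down by the
	// expected number of samples per window.
	func rttUpdate(srtt, rttvar, sample time.Duration, outstanding int, perAck bool) (time.Duration, time.Duration, time.Duration) {
		diff := srtt - sample
		if diff < 0 {
			diff = -diff
		}
		if !perAck {
			rttvar = (3*rttvar + diff) / 4
			srtt = (7*srtt + sample) / 8
		} else {
			expected := math.Ceil(float64(outstanding) / 2)
			alpha := 0.125 / expected
			beta := 0.25 / expected
			rv := (1-beta)*rttvar.Seconds() + beta*diff.Seconds()
			sr := (1-alpha)*srtt.Seconds() + alpha*sample.Seconds()
			rttvar = time.Duration(rv * float64(time.Second))
			srtt = time.Duration(sr * float64(time.Second))
		}
		// RFC 6298 section 2: RTO = SRTT + 4*RTTVar; the caller clamps
		// the result to minRTO.
		return srtt, rttvar, srtt + 4*rttvar
	}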
@@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) {
func (s *sender) resendSegment() {
// Don't use any segments we already sent to measure RTT as they may
// have been affected by packets being lost.
- s.rttMeasureSeqNum = s.sndNxt
+ s.RTTMeasureSeqNum = s.SndNxt
// Resend the segment.
if seg := s.writeList.Front(); seg != nil {
- if seg.data.Size() > s.maxPayloadSize {
- s.splitSeg(seg, s.maxPayloadSize)
+ if seg.data.Size() > s.MaxPayloadSize {
+ s.splitSeg(seg, s.MaxPayloadSize)
}
// See: RFC 6675 section 5 Step 4.3
//
// To prevent retransmission, set both the HighRXT and RescueRXT
// to the highest sequence number in the retransmitted segment.
- s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
- s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
s.ep.stats.SendErrors.FastRetransmit.Increment()
@@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool {
// Set new timeout. The timer will be restarted by the call to sendData
// below.
- s.rto *= 2
+ s.RTO *= 2
// Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5
- if s.rto > s.maxRTO {
- s.rto = s.maxRTO
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
}
// Cap RTO to remaining time.
- if s.rto > remaining {
- s.rto = remaining
+ if s.RTO > remaining {
+ s.RTO = remaining
}
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
@@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool {
// After a retransmit timeout, record the highest sequence number
// transmitted in the variable recover, and exit the fast recovery
// procedure if applicable.
- s.fr.last = s.sndNxt - 1
+ s.FastRecovery.Last = s.SndNxt - 1
- if s.fr.active {
+ if s.FastRecovery.Active {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
@@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool {
//
// We'll keep on transmitting (or retransmitting) as we get acks for
// the data we transmit.
- s.outstanding = 0
+ s.Outstanding = 0
// Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
//
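The RTO backoff earlier in this hunk reduces to a small helper; a hedged sketch with illustrative names, not part of the tree:

	package tcpsketch

	import "time"

	// backoffRTO doubles the RTO on a retransmit timeout and clamps it to
	// the configured maximum (RFC 1122 4.2.3.1, RFC 6298 5.5) and to the
	// time remaining before the user-visible timeout.
	func backoffRTO(rto, maxRTO, remaining time.Duration) time.Duration {
		rto *= 2
		if rto > maxRTO {
			rto = maxRTO
		}
		if rto > remaining {
			rto = remaining
		}
		return rto
	}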
@@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) {
// window space.
// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
// ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
- if seg.data.Size() > s.maxPayloadSize {
+ if seg.data.Size() > s.MaxPayloadSize {
seg.flags ^= header.TCPFlagPsh
}
@@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// transmitted (i.e. either it has no assigned sequence number
// or if it does have one, it's >= the next sequence number
// to be sent [i.e. >= s.sndNxt]).
- if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) {
+ if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) {
hint = nil
break
}
@@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// (1.a) S2 is greater than HighRxt
// (1.b) S2 is less than the highest octet covered by
// any received SACK.
- if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
// NextSeg():
// (1.c) IsLost(S2) returns true.
if s.ep.scoreboard.IsLost(segSeq) {
@@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// unSACKed sequence number SHOULD be returned, and
// RescueRxt set to RecoveryPoint. HighRxt MUST NOT
// be updated.
- if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) {
if s4 != nil {
if s4.sequenceNumber.LessThan(segSeq) {
s4 = seg
@@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// previously unsent data starting with sequence number
// HighData+1 MUST be returned."
for seg := s.writeNext; seg != nil; seg = seg.Next() {
- if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) {
continue
}
// We do not split the segment here to <= smss as it has
@@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(s.sndNxt.Size(end))
+ available := int(s.SndNxt.Size(end))
if available > limit {
available = limit
}
@@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
+ if s.Outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
+ if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
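The Nagle/cork gate above reduces to a small predicate; a simplified sketch with illustrative names (the real code also merges pending data before deciding):

	package tcpsketch

	// shouldDelaySend mirrors the hold-back logic in maybeSendSegment:
	// keep a sub-MSS segment queued while earlier data is still
	// unacknowledged (Nagle) or while the endpoint is corked, so small
	// writes can coalesce into fuller segments.
	func shouldDelaySend(segLen, mss, outstanding int, nagle, cork bool) bool {
		if segLen >= mss {
			return false // Full-sized segments are always eligible.
		}
		if nagle && outstanding > 0 {
			return true // Nagle: wait for in-flight data to be acked.
		}
		return cork // TCP_CORK: hold partial segments.
	}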
@@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// Assign flags. We don't do it above so that we can merge
// additional data if Nagle holds the segment.
- seg.sequenceNumber = s.sndNxt
+ seg.sequenceNumber = s.SndNxt
seg.flags = header.TCPFlagAck | header.TCPFlagPsh
}
@@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// the segment right here if there are no pending segments. If
// there are pending segments, segment transmits are deferred to
// the retransmit timer handler.
- if s.sndUna != s.sndNxt {
+ if s.SndUna != s.SndNxt {
switch {
case available >= seg.data.Size():
// OK to send, the whole segment fits in the
// receiver's advertised window.
- case available >= s.maxPayloadSize:
+ case available >= s.MaxPayloadSize:
// OK to send, at least 1 MSS sized segment fits
// in the receiver's advertised window.
default:
@@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// If GSO is not in use then cap available to
// maxPayloadSize. When GSO is in use the gVisor GSO logic or
// the host GSO logic will cap the segment to the correct size.
- if s.ep.gso == nil && available > s.maxPayloadSize {
- available = s.maxPayloadSize
+ if s.ep.gso.Type == stack.GSONone && available > s.MaxPayloadSize {
+ available = s.MaxPayloadSize
}
if seg.data.Size() > available {
@@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// Update sndNxt if we actually sent new data (as opposed to
// retransmitting some previously sent data).
- if s.sndNxt.LessThan(segEnd) {
- s.sndNxt = segEnd
+ if s.SndNxt.LessThan(segEnd) {
+ s.SndNxt = segEnd
}
return true
@@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() {
s.unackZeroWindowProbes++
// Send a zero window probe with sequence number pointing to
// the last acknowledged byte.
- s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win)
+ s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win)
// Rearm the timer to continue probing.
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
func (s *sender) enableZeroWindowProbing() {
@@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() {
if s.firstRetransmittedSegXmitTime.IsZero() {
s.firstRetransmittedSegXmitTime = time.Now()
}
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
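A compact sketch of the probing scheme above; probe and rearm are hypothetical stand-ins for the endpoint's raw-send and timer hooks:

	package tcpsketch

	import "time"

	// zeroWindowProbe sends a zero-length segment whose sequence number
	// points at the last acknowledged byte (SndUna - 1). The peer must ACK
	// it, and that ACK carries its current window, so the sender learns as
	// soon as receive buffer space frees up; the timer is rearmed with the
	// current RTO to keep probing.
	func zeroWindowProbe(sndUna uint32, rto time.Duration, probe func(seq uint32), rearm func(time.Duration)) {
		probe(sndUna - 1)
		rearm(rto)
	}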
func (s *sender) disableZeroWindowProbing() {
@@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
// If the remote has advertised a zero receive window and we have
// data to be sent out, start zero window probing to query the
// remote for its receive window size.
- if s.writeNext != nil && s.sndWnd == 0 {
+ if s.writeNext != nil && s.SndWnd == 0 {
s.enableZeroWindowProbing()
}
// If we have no more pending data, start the keepalive timer.
- if s.sndUna == s.sndNxt {
+ if s.SndUna == s.SndNxt {
s.ep.resetKeepaliveTimer(false)
} else {
// Enable timers if we have pending data.
@@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
s.schedulePTO()
} else if !s.resendTimer.enabled() {
s.probeTimer.disable()
- if s.outstanding > 0 {
+ if s.Outstanding > 0 {
// Enable the resend timer if it's not enabled yet and there is
// outstanding data.
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
}
}
@@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
- limit := s.maxPayloadSize
+ limit := s.MaxPayloadSize
if s.gso {
limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
}
- end := s.sndUna.Add(s.sndWnd)
+ end := s.SndUna.Add(s.SndWnd)
// Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retransmission timeout."
- if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
- if s.sndCwnd > InitialCwnd {
- s.sndCwnd = InitialCwnd
+ if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO {
+ if s.SndCwnd > InitialCwnd {
+ s.SndCwnd = InitialCwnd
}
}
var dataSent bool
- for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() {
+ cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize
if cwndLimit < limit {
limit = cwndLimit
}
- if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
// Move writeNext along so that we don't try and scan data that
// has already been SACKED.
s.writeNext = seg.Next()
@@ -1036,7 +948,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg, s.maxPayloadSize)
+ s.Outstanding += s.pCount(seg, s.MaxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1044,21 +956,21 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
- s.fr.active = true
+ s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
- s.sndCwnd = s.sndSsthresh + 3
- s.sackedOut = 0
- s.dupAckCount = 0
- s.fr.first = s.sndUna
- s.fr.last = s.sndNxt - 1
- s.fr.maxCwnd = s.sndCwnd + s.outstanding
- s.fr.highRxt = s.sndUna
- s.fr.rescueRxt = s.sndUna
- if s.ep.sackPermitted {
+ s.SndCwnd = s.Ssthresh + 3
+ s.SackedOut = 0
+ s.DupAckCount = 0
+ s.FastRecovery.First = s.SndUna
+ s.FastRecovery.Last = s.SndNxt - 1
+ s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
+ s.FastRecovery.HighRxt = s.SndUna
+ s.FastRecovery.RescueRxt = s.SndUna
+ if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
// Set TLPRxtOut to false according to
@@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() {
}
func (s *sender) leaveRecovery() {
- s.fr.active = false
- s.fr.maxCwnd = 0
- s.dupAckCount = 0
+ s.FastRecovery.Active = false
+ s.FastRecovery.MaxCwnd = 0
+ s.DupAckCount = 0
// Deflate cwnd. It had been artificially inflated when new dups arrived.
- s.sndCwnd = s.sndSsthresh
+ s.SndCwnd = s.Ssthresh
s.cc.PostRecovery()
}
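The window bookkeeping in enterRecovery/leaveRecovery follows RFC 5681 section 3.2; a sketch of just that arithmetic (illustrative names):

	package tcpsketch

	// recoveryWindow captures the cwnd values around fast recovery: on
	// entry cwnd is ssthresh + 3 (the three duplicate ACKs imply three
	// packets have left the network), bounded by maxCwnd = cwnd +
	// outstanding so a misbehaving receiver cannot inflate the window
	// with bogus duplicate ACKs; on exit cwnd deflates back to ssthresh.
	type recoveryWindow struct{ cwnd, maxCwnd int }

	func enterRecoveryWindow(ssthresh, outstanding int) recoveryWindow {
		cwnd := ssthresh + 3
		return recoveryWindow{cwnd: cwnd, maxCwnd: cwnd + outstanding}
	}

	func leaveRecoveryWindow(ssthresh int) recoveryWindow {
		return recoveryWindow{cwnd: ssthresh}
	}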
@@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
func (s *sender) SetPipe() {
// If SACK isn't permitted or it is permitted but recovery is not active
// then ignore pipe calculations.
- if !s.ep.sackPermitted || !s.fr.active {
+ if !s.ep.SACKPermitted || !s.FastRecovery.Active {
return
}
pipe := 0
@@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() {
// After initializing pipe to zero, the following steps are
// taken for each octet 'S1' in the sequence space between
// HighACK and HighData that has not been SACKed:
- if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ if !s1.sequenceNumber.LessThan(s.SndNxt) {
break
}
if s.ep.scoreboard.IsSACKED(sb) {
@@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() {
}
// SetPipe():
// (b) If S1 <= HighRxt, Pipe is incremented by 1.
- if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) {
pipe++
}
}
}
- s.outstanding = pipe
+ s.Outstanding = pipe
}
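SetPipe's per-octet rules from RFC 6675 can be summarized per segment. A simplified, per-packet sketch follows; the real code walks SACK ranges on the scoreboard and counts packets via pCount:

	package tcpsketch

	// sketchSeg marks the scoreboard facts SetPipe needs per segment.
	type sketchSeg struct {
		sacked, lost, retransmitted bool
	}

	// pipeEstimate follows RFC 6675 SetPipe(): for each un-SACKed segment
	// between HighACK and HighData, count one copy in flight if it is not
	// lost (rule a), and one more if it was retransmitted, i.e. its
	// sequence number is at or below HighRxt (rule b).
	func pipeEstimate(segs []sketchSeg) int {
		pipe := 0
		for _, s := range segs {
			if s.sacked {
				continue // SACKed data is not in the pipe.
			}
			if !s.lost {
				pipe++
			}
			if s.retransmitted {
				pipe++
			}
		}
		return pipe
	}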
// shouldEnterRecovery returns true if the sender should enter fast recovery
// based on dupAck count and sack scoreboard.
// See RFC 6675 section 5.
func (s *sender) shouldEnterRecovery() bool {
- return s.dupAckCount >= nDupAckThreshold ||
- (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna))
+ return s.DupAckCount >= nDupAckThreshold ||
+ (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna))
}
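The trigger above as a standalone predicate; rackEnabled stands in for the tcpRecovery RACK bit:

	package tcpsketch

	// shouldEnterFastRecovery fires on the classic duplicate-ACK threshold
	// or, with SACK enabled and RACK disabled, on the RFC 6675 IsLost()
	// test for the first unacknowledged sequence number.
	func shouldEnterFastRecovery(dupAcks, threshold int, sackEnabled, rackEnabled, firstUnackedLost bool) bool {
		return dupAcks >= threshold ||
			(sackEnabled && !rackEnabled && firstUnackedLost)
	}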
// detectLoss is called when an ack is received and returns whether a loss is
@@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// If RACK is enabled and there is no reordering we should honor the
// three duplicate ACK rule to enter recovery.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4
- if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- if s.rc.reorderSeen {
+ if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.rc.Reord {
return false
}
}
if !s.isDupAck(seg) {
- s.dupAckCount = 0
+ s.DupAckCount = 0
return false
}
- s.dupAckCount++
+ s.DupAckCount++
// Do not enter fast recovery until we reach nDupAckThreshold or the
// first unacknowledged byte is considered lost as per SACK scoreboard.
if !s.shouldEnterRecovery() {
// RFC 6675 Step 3.
- s.fr.highRxt = s.sndUna - 1
+ s.FastRecovery.HighRxt = s.SndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
s.state = tcpip.Disorder
@@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// Note that we only enter recovery when at least one more byte of data
// beyond s.fr.last (the highest byte that was outstanding when fast
// retransmit was last entered) is acked.
- if !s.fr.last.LessThan(seg.ackNumber - 1) {
- s.dupAckCount = 0
+ if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) {
+ s.DupAckCount = 0
return false
}
s.cc.HandleLossDetected()
@@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool {
// can leverage the SACK information to determine when an incoming ACK is a
// "duplicate" (e.g., if the ACK contains previously unknown SACK
// information).
- if s.ep.sackPermitted && !seg.hasNewSACKInfo {
+ if s.ep.SACKPermitted && !seg.hasNewSACKInfo {
return false
}
// (a) The receiver of the ACK has outstanding data.
- return s.sndUna != s.sndNxt &&
+ return s.SndUna != s.SndNxt &&
// (b) The incoming acknowledgment carries no data.
seg.logicalLen() == 0 &&
// (c) The SYN and FIN bits are both off.
!seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) &&
// (d) the ACK number is equal to the greatest acknowledgment received on
// the given connection (TCP.UNA from RFC793).
- seg.ackNumber == s.sndUna &&
+ seg.ackNumber == s.SndUna &&
// (e) the advertised window in the incoming acknowledgment equals the
// advertised window in the last incoming acknowledgment.
- s.sndWnd == seg.window
+ s.SndWnd == seg.window
}
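For clarity, the five RFC 5681/6675 conditions above as one boolean; plain uint32 comparisons are used for brevity, ignoring sequence-number wraparound, which the real seqnum package handles:

	package tcpsketch

	// isDuplicateAck requires: (a) outstanding data, (b) no payload,
	// (c) neither SYN nor FIN set, (d) an ack number equal to SND.UNA, and
	// (e) an unchanged advertised window.
	func isDuplicateAck(sndUna, sndNxt, ackNum uint32, payloadLen int, syn, fin bool, sndWnd, segWnd uint16) bool {
		return sndUna != sndNxt &&
			payloadLen == 0 &&
			!syn && !fin &&
			ackNum == sndUna &&
			sndWnd == segWnd
	}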
// Iterate the writeList and update RACK for each segment which is newly acked
@@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
seg.acked = true
- s.sackedOut += s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut += s.pCount(seg, s.MaxPayloadSize)
}
seg = seg.Next()
}
@@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool {
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
- s.updateRTO(time.Now().Sub(s.rttMeasureTime))
- s.rttMeasureSeqNum = s.sndNxt
+ if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.RTTMeasureTime))
+ s.RTTMeasureSeqNum = s.SndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
+ if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
@@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
rcvdSeg.hasNewSACKInfo = true
}
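The insertion filter above, restated as a sketch (again with plain uint32 comparisons in place of the wraparound-aware seqnum types):

	package tcpsketch

	// sackBlockAcceptable admits a received SACK block into the scoreboard
	// only if it lies strictly above both the cumulative ACK and SND.UNA
	// (which excludes DSACK blocks for already-acked data) and does not
	// extend beyond SND.NXT.
	func sackBlockAcceptable(ackNum, sndUna, sndNxt, start, end uint32) bool {
		return ackNum < start && sndUna < start && end <= sndNxt
	}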
@@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ack := rcvdSeg.ackNumber
fastRetransmit := false
// Do not leave fast recovery if the ACK is out of range.
- if s.fr.active {
+ if s.FastRecovery.Active {
// Leave fast recovery if it acknowledges all the data covered by
// this fast recovery session.
- if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) {
+ if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) {
s.leaveRecovery()
}
} else {
@@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Stash away the current window size.
- s.sndWnd = rcvdSeg.window
+ s.SndWnd = rcvdSeg.window
// Disable zero window probing if the remote advertises a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// ack number refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
if s.zeroWindowProbing && rcvdSeg.window > 0 &&
- (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
+ (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) {
s.disableZeroWindowProbing()
}
// On receiving the ACK for the zero window probe, account for it and
// skip trying to send any segment as we are still probing for
// the receive window to become non-zero.
- if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna {
+ if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna {
s.unackZeroWindowProbes--
return
}
// Ignore ack if it doesn't acknowledge any new data.
- if (ack - 1).InRange(s.sndUna, s.sndNxt) {
- s.dupAckCount = 0
+ if (ack - 1).InRange(s.SndUna, s.SndNxt) {
+ s.DupAckCount = 0
// See : https://tools.ietf.org/html/rfc1323#section-3.3.
// Specifically we should only update the RTO using TSEcr if the
@@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
+ if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
@@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// When an ack is received we must rearm the timer.
// RFC 6298 5.3
s.probeTimer.disable()
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
// Remove all acknowledged data from the write list.
- acked := s.sndUna.Size(ack)
- s.sndUna = ack
+ acked := s.SndUna.Size(ack)
+ s.SndUna = ack
// The remote ACK-ing at least 1 byte is an indication that we have a
// full-duplex connection to the remote as the only way we will receive an
@@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
ackLeft := acked
- originalOutstanding := s.outstanding
+ originalOutstanding := s.Outstanding
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg, s.maxPayloadSize)
+ prevCount := s.pCount(seg, s.MaxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
+ s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize)
break
}
@@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
}
@@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
- if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.Outstanding -= s.pCount(seg, s.MaxPayloadSize)
} else {
- s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut -= s.pCount(seg, s.MaxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
@@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.ep.updateSndBufferUsage(int(acked))
// Clear SACK information for all acked data.
- s.ep.scoreboard.Delete(s.sndUna)
+ s.ep.scoreboard.Delete(s.SndUna)
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
- if !s.fr.active {
- s.cc.Update(originalOutstanding - s.outstanding)
- if s.fr.last.LessThan(s.sndUna) {
+ if !s.FastRecovery.Active {
+ s.cc.Update(originalOutstanding - s.Outstanding)
+ if s.FastRecovery.Last.LessThan(s.SndUna) {
s.state = tcpip.Open
// Update RACK when we are exiting fast or RTO
// recovery as described in the RFC
@@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that covers previously sent data.
- if s.outstanding < 0 {
- s.outstanding = 0
+ if s.Outstanding < 0 {
+ s.Outstanding = 0
}
s.SetPipe()
// If all outstanding data was acknowledged, disable the timer.
// RFC 6298 Rule 5.3
- if s.sndUna == s.sndNxt {
- s.outstanding = 0
+ if s.SndUna == s.SndNxt {
+ s.Outstanding = 0
// Reset firstRetransmittedSegXmitTime to the zero value.
s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
@@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
- if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
// Update RACK reorder window.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// * Upon receiving an ACK:
@@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// After the reorder window is calculated, detect any loss by checking
// if the time elapsed after the segments are sent is greater than the
// reorder window.
- if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active {
+ if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active {
// If any segment is marked as lost by
// RACK, enter recovery and retransmit
// the lost segments.
@@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
fastRetransmit = true
}
- if s.fr.active {
+ if s.FastRecovery.Active {
s.rc.DoRecovery(nil, fastRetransmit)
}
}
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 {
+ if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 {
s.lr.DoRecovery(rcvdSeg, fastRetransmit)
// When SACK is enabled data sending is governed by steps in
// RFC 6675 Section 5 recovery steps A-C.
// See: https://tools.ietf.org/html/rfc6675#section-5.
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
return
}
}
@@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
- if s.sndCwnd < s.sndSsthresh {
+ if s.SndCwnd < s.Ssthresh {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
}
@@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// then use the conservative timer described in RFC6675 Section 6.0,
// otherwise follow the standard timer described in RFC6298 Section 5.1.
if err != nil && seg.data.Size() != 0 {
- if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted {
- s.resendTimer.enable(s.rto)
+ if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted {
+ s.resendTimer.enable(s.RTO)
} else {
if !s.resendTimer.enabled() {
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
}
}
@@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
- s.lastSendTime = time.Now()
- if seq == s.rttMeasureSeqNum {
- s.rttMeasureTime = s.lastSendTime
+ s.LastSendTime = time.Now()
+ if seq == s.RTTMeasureSeqNum {
+ s.RTTMeasureTime = s.LastSendTime
}
rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
// Remember the max sent ack.
- s.maxSentAck = rcvNxt
+ s.MaxSentAck = rcvNxt
return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index ba41cff6d..2f805d8ce 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -24,26 +24,6 @@ type unixTime struct {
nano int64
}
-// saveLastSendTime is invoked by stateify.
-func (s *sender) saveLastSendTime() unixTime {
- return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
-}
-
-// loadLastSendTime is invoked by stateify.
-func (s *sender) loadLastSendTime(unix unixTime) {
- s.lastSendTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRttMeasureTime is invoked by stateify.
-func (s *sender) saveRttMeasureTime() unixTime {
- return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
-}
-
-// loadRttMeasureTime is invoked by stateify.
-func (s *sender) loadRttMeasureTime(unix unixTime) {
- s.rttMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
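The removed helpers were instances of the usual stateify pattern for time.Time fields, which are saved as seconds plus nanoseconds; a generic sketch of that pattern, mirroring the deleted code:

	package tcpsketch

	import "time"

	type unixTime struct{ second, nano int64 }

	// saveTime/loadTime convert between time.Time and its serializable
	// form, exactly as the deleted save/load pairs did for lastSendTime
	// and rttMeasureTime (now owned by stack.TCPSenderState).
	func saveTime(t time.Time) unixTime { return unixTime{t.Unix(), t.UnixNano()} }
	func loadTime(u unixTime) time.Time { return time.Unix(u.second, u.nano) }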
// afterLoad is invoked by stateify.
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 5cdd5b588..c58361bc1 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,6 +33,7 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+ latency = 5 * time.Millisecond
)
func setStackRACKPermitted(t *testing.T, c *context.Context) {
@@ -182,6 +183,9 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
+ // This delay is added to increase the RTT, as a low RTT can cause a
+ // TLP to fire before the ACK is sent.
+ time.Sleep(latency)
}
return data
@@ -479,7 +483,7 @@ func TestRACKOnePacketTailLoss(t *testing.T) {
}{
// #3 was retransmitted as TLP.
{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
{tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
// RTO should not have fired.
{tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
@@ -852,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan
return
}
- if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT {
- probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT)
+ if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT {
+ probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT)
return
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 81f800cad..20c9761f2 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9c23469f2..9916182e3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
"gvisor.dev/gvisor/pkg/test/testutil"
@@ -86,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
}
for w.N != 0 {
_, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for receive to be notified.
select {
case <-notifyRead:
@@ -129,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) {
{
err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -144,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) {
// and stats updates.
{
err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrAborted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{})
+ if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -158,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) {
}
}
+// Test for ICMP error handling without completing the handshake.
+func TestConnectICMPError(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventHUp)
+ defer wq.EventUnregister(&waitEntry)
+
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
+ }
+ }
+
+ syn := c.GetPacket()
+ checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn)))
+
+ wep := ep.(interface {
+ StopWork()
+ ResumeWork()
+ LastErrorLocked() tcpip.Error
+ })
+
+ // Stop the protocol loop, ensure that the ICMP error is processed, and
+ // read the last ICMP error before the loop is resumed. This sanity-checks
+ // the handshake completion logic on ICMP errors.
+ wep.StopWork()
+
+ c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU)
+
+ for {
+ if err := wep.LastErrorLocked(); err != nil {
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d)
+ }
+ break
+ }
+ time.Sleep(time.Millisecond)
+ }
+
+ wep.ResumeWork()
+
+ <-notifyCh
+
+ // The stack would have unregistered the endpoint because of the ICMP error.
+ // Expect a RST for any subsequent packets sent to the endpoint.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1,
+ AckNum: c.IRS + 1,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
func TestConnectIncrementActiveConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -201,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{})
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -392,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -929,17 +1000,14 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
}
// Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
- }
+ rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize()
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
{
err := c.EP.Connect(connectAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d)
}
}
@@ -955,11 +1023,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
// when completing the handshake for a new TCP connection from a TCP
// listening socket. It should be present in the sent TCP SYN-ACK segment.
func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
- const (
- nonSynCookieAccepts = 2
- totalAccepts = 4
- mtu = 5000
- )
+ const mtu = 5000
ips := []struct {
name string
@@ -1033,12 +1097,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
ip.createEP(c)
- // Set the SynRcvd threshold to force a syn cookie based accept to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
}
@@ -1048,13 +1106,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
t.Fatalf("Bind(%+v): %s:", bindAddr, err)
}
- if err := c.EP.Listen(totalAccepts); err != nil {
- t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ backlog := 5
+ // Keep the number of client requests at twice the backlog
+ // so that half of the connections do not use syn cookies
+ // and the other half does.
+ clientConnects := backlog * 2
+
+ if err := c.EP.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s:", backlog, err)
}
- // The first nonSynCookieAccepts packets sent will trigger a goroutine
- // based accept. The rest will trigger a cookie based accept.
- for i := 0; i < totalAccepts; i++ {
+ for i := 0; i < clientConnects; i++ {
// Send a SYN request.
iss := seqnum.Value(i)
srcPort := context.TestPort + uint16(i)
@@ -1297,6 +1359,98 @@ func TestListenShutdown(t *testing.T) {
))
}
+var _ waiter.EntryCallback = (callback)(nil)
+
+type callback func(*waiter.Entry, waiter.EventMask)
+
+func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
+ cb(entry, mask)
+}
+
+func TestListenerReadinessOnEvent(t *testing.T) {
+ s := stack.New(stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ {
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+ const id = 1
+ if err := s.CreateNIC(id, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
+ }
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
+ t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: header.IPv4EmptySubnet, NIC: id},
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
+ t.Fatalf("Bind(%s): %s", context.StackAddr, err)
+ }
+ const backlog = 1
+ if err := ep.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s", backlog, err)
+ }
+
+ address, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress(): %s", err)
+ }
+
+ conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer conn.Close()
+
+ events := make(chan waiter.EventMask)
+ // Scope `entry` to allow a binding of the same name below.
+ {
+ entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
+ events <- ep.Readiness(mask)
+ })}
+ wq.EventRegister(&entry, waiter.EventIn)
+ defer wq.EventUnregister(&entry)
+ }
+
+ entry, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&entry, waiter.EventOut)
+ defer wq.EventUnregister(&entry)
+
+ switch err := conn.Connect(address).(type) {
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect(%#v): %v", address, err)
+ }
+
+ // Read at least one event.
+ got := <-events
+ for {
+ select {
+ case event := <-events:
+ got |= event
+ continue
+ case <-ch:
+ if want := waiter.ReadableEvents; got != want {
+ t.Errorf("observed events = %b, want %b", got, want)
+ }
+ }
+ break
+ }
+}
+
// TestListenCloseWhileConnect tests that the listening endpoint drains
// the accept-queue when closed. This should reset all of the
// pending connections that are waiting to be accepted.
@@ -1459,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -1520,8 +1674,8 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -1993,9 +2147,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true)
// Keep the payload size < segment overhead and such that it is a multiple
// of the window scaled value. This enables the test to perform equality
@@ -2115,9 +2267,7 @@ func TestNoWindowShrinking(t *testing.T) {
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true)
data := generateRandomPayload(t, rcvBufSize)
// Send a payload of half the size of rcvBufSize.
@@ -2373,9 +2523,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -2395,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -2447,9 +2595,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -2469,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3001,8 +3147,8 @@ func TestSetTTL(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -3042,9 +3188,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3063,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3087,11 +3231,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, mtu)
defer c.Cleanup()
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -3119,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3185,9 +3327,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 3
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
@@ -3196,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -3315,8 +3455,8 @@ loop:
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
- t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
break loop
case <-time.After(1 * time.Second):
@@ -3366,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) {
var r bytes.Reader
r.Reset(make([]byte, 10))
_, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
- t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d)
}
}
@@ -4320,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) {
var buf bytes.Buffer
{
_, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
- if _, ok := err.(*tcpip.ErrClosedForReceive); !ok {
- t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{})
+ if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d)
}
}
}
@@ -4365,8 +4505,8 @@ func TestReusePort(t *testing.T) {
}
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
c.EP.Close()
@@ -4411,11 +4551,7 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
- }
-
+ s := ep.SocketOptions().GetReceiveBufferSize()
if int(s) != v {
t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
@@ -4521,10 +4657,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min/2.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
- }
-
+ ep.SocketOptions().SetReceiveBufferSize(99, true)
checkRecvBufferSize(t, ep, 200)
ep.SocketOptions().SetSendBufferSize(149, true)
@@ -4532,15 +4665,11 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
-
+ ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true)
// Values above max are capped at max and then doubled.
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
-
// Values above max are capped at max and then doubled.
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
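
The assertions above pin down the buffer-size policy: a request of 99 lands at 200, and a request just past the maximum lands at twice the maximum. In other words, the value is clamped to [min, max] and then doubled, much as Linux doubles SO_RCVBUF to cover bookkeeping overhead. A sketch of that arithmetic, inferred from the test expectations rather than copied from the stack:

// clampAndDouble captures the behavior the checks above assert; it is
// an inference from the test, not gVisor's actual setter.
func clampAndDouble(v, min, max int64) int64 {
	if v < min {
		v = min
	}
	if v > max {
		v = max
	}
	return 2 * v
}

// With min = 100 and max = 20*DefaultReceiveBufferSize:
//   clampAndDouble(99, min, max)    == 200
//   clampAndDouble(max+1, min, max) == 2 * max
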
@@ -4665,8 +4794,8 @@ func TestSelfConnect(t *testing.T) {
{
err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -4814,7 +4943,13 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("unknown address type: '%s'", candidateAddressType)
}
- start, end := s.PortRange()
+ const (
+ start = 16000
+ end = 16050
+ )
+ if err := s.SetPortRange(start, end); err != nil {
+ t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
+ }
for i := start; i <= end; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
@@ -5387,7 +5522,7 @@ func TestListenBacklogFull(t *testing.T) {
for i := 0; i < listenBacklog; i++ {
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5404,7 +5539,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5416,7 +5551,7 @@ func TestListenBacklogFull(t *testing.T) {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5445,8 +5580,8 @@ func TestListenBacklogFull(t *testing.T) {
// TestListenNoAcceptNonUnicastV4 makes sure that TCP segments with a
// non-unicast IPv4 address are not accepted.
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
- multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
- otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ multicastAddr := tcpiptestutil.MustParse4("224.0.1.2")
+ otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3")
subnet := context.StackAddrWithPrefix.Subnet()
subnetBroadcastAddr := subnet.Broadcast()
@@ -5557,8 +5692,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
// non-unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpiptestutil.MustParse6("ff0e::101")
+ otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102")
tests := []struct {
name string
@@ -5671,15 +5806,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
// Test acceptance.
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %s", err)
}
// Send two SYNs; the first one should get a SYN-ACK, the
// second one should not get any response and is dropped as
- // the synRcvd count will be equal to backlog.
+ // the accept queue is full.
irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5701,23 +5834,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- //
- // NOTE: we did not complete the handshake for the previous one so the
- // accept backlog should be empty and there should be one connection in
- // synRcvd state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Now complete the previous connection and verify that there is a connection
- // to accept.
+ // Now complete the previous connection.
// Send ACK.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5728,13 +5845,26 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
RcvWnd: 30000,
})
- // Try to accept the connections in the backlog.
+ // Verify that the completed connection is delivered to the accept queue.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.ReadableEvents)
defer c.WQ.EventUnregister(&we)
+ <-ch
+
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+ // Try to accept the connections in the backlog.
newEP, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5764,11 +5894,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.TCPSynRcvdCountThresholdOption(1)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
// Create TCP endpoint.
var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -5781,9 +5906,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
t.Fatalf("Bind failed: %s", err)
}
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ // Test that SYN cookies are used once the listen backlog is filled up.
+ if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %s", err)
}
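
TCPAlwaysUseSynCookies(true) makes the listener answer every SYN with a cookie instead of keeping half-open connection state. The idea, in a toy sketch (this is not gVisor's cookie format, which also encodes the peer's MSS): derive the initial sequence number from a keyed hash of the 4-tuple and a coarse clock, and validate the final ACK by recomputing it.

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// cookieISN derives an initial sequence number from the 4-tuple, a
// secret, and a coarse time counter.
func cookieISN(srcIP, dstIP [4]byte, srcPort, dstPort uint16, secret uint64, epoch uint32) uint32 {
	h := fnv.New64a()
	h.Write(srcIP[:])
	h.Write(dstIP[:])
	var b [16]byte
	binary.BigEndian.PutUint16(b[0:2], srcPort)
	binary.BigEndian.PutUint16(b[2:4], dstPort)
	binary.BigEndian.PutUint64(b[4:12], secret)
	binary.BigEndian.PutUint32(b[12:16], epoch)
	h.Write(b[:])
	return uint32(h.Sum64())
}

func main() {
	src, dst := [4]byte{192, 168, 1, 2}, [4]byte{10, 0, 0, 1}
	const secret = 0xdeadbeefcafef00d
	epoch := uint32(12345) // from a slow-ticking clock in a real stack

	// The SYN-ACK carries the cookie as its ISS; no per-connection
	// state is stored.
	iss := cookieISN(src, dst, 4096, 80, secret, epoch)

	// The final ACK must acknowledge iss+1 for a recent epoch; the
	// listener validates it by recomputing the cookie.
	ackNum := iss + 1
	fmt.Println("ACK valid:", ackNum-1 == cookieISN(src, dst, 4096, 80, secret, epoch))
}
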
@@ -5811,7 +5935,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
defer c.WQ.EventUnregister(&we)
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5827,7 +5951,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5966,7 +6090,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
t.Fatalf("Accept failed: %s", err)
}
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.ReadableEvents)
@@ -6034,7 +6158,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
// Verify that there is only one acceptable connection at this point.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6104,7 +6228,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
// Now check that there is one acceptable connection.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6156,7 +6280,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
defer wq.EventUnregister(&we)
aep, _, err := ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6174,8 +6298,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
{
err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok {
- t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{})
+ if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" {
+ t.Errorf("Connect(...) mismatch (-want +got):\n%s", d)
}
}
// Listening endpoint remains in listen state.
@@ -6295,7 +6419,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// window increases to the full available buffer size.
for {
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
break
}
}
@@ -6426,7 +6550,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalCopied := 0
for {
res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
break
}
totalCopied += res.Count
@@ -6618,7 +6742,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6737,7 +6861,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6844,7 +6968,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6934,7 +7058,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Try to accept the connection.
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7008,7 +7132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7158,7 +7282,7 @@ func TestTCPCloseWithData(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7553,8 +7677,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Increasing the buffer size should generate an ACK,
// since the window grew from a small value to one at least equal to the MSS.
- c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
-
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
@@ -7590,8 +7713,8 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
_, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
+ if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
+ t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
}
// Send data. This should result in an acceptable endpoint.
@@ -7649,8 +7772,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
_, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
+ if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
+ t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
}
// Sleep for a little of the tcpDeferAccept timeout.
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 2949588ce..1deb1fe4d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e73f90bb0..53efecc5a 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -331,8 +331,8 @@ func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
+ if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize)
}
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
@@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) {
}
if epRcvBuf != -1 {
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */)
}
}
@@ -1216,9 +1214,9 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
if enable {
- c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.HWGSOSupported
} else {
- c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.GSONotSupported
}
}
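
SetGSOEnabled now assigns a GSO kind instead of toggling a capability bit. The miniature version of that design change, with stand-in names:

package main

import "fmt"

type caps uint32

const capHWGSO caps = 1 << 3

type gsoKind int

const (
	gsoNotSupported gsoKind = iota
	hwGSOSupported
)

func main() {
	// Old: one bit in a shared capability mask, set with |= and
	// cleared with &^= (Go's AND NOT assignment).
	var c caps
	c |= capHWGSO
	c &^= capHWGSO

	// New: a dedicated kind value makes the mutually exclusive states
	// explicit and leaves room for more than two of them.
	k := hwGSOSupported
	k = gsoNotSupported
	fmt.Println(c, k)
}
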
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 153e8c950..dd5c910ae 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 956da0e0c..f7dd50d35 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "fmt"
"io"
"sync/atomic"
@@ -89,12 +88,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList udpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -144,6 +142,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates that packet delivery to the endpoint is suspended,
+ // as it is while the endpoint's state is being saved for restore.
+ frozen bool
}
// +stateify savable
@@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- rcvBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
return e
@@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
e.mu.Lock()
e.sendTOS = uint8(v)
e.mu.Unlock()
-
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
- }
-
- if v < rs.Min {
- v = rs.Min
- }
- if v > rs.Max {
- v = rs.Max
- }
-
- e.mu.Lock()
- e.rcvBufSizeMax = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.mu.Lock()
v := int(e.ttl)
@@ -872,7 +848,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
TTL: ttl,
TOS: tos,
@@ -1255,20 +1231,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
-// On IPv4, UDP checksum is optional, and a zero value means the transmitter
-// omitted the checksum generation (RFC768).
-// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
- if !pkt.RXTransportChecksumValidated &&
- (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
- netHdr := pkt.Network()
- xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
- for _, v := range pkt.Data().Views() {
- xsum = header.Checksum(v, xsum)
- }
- return hdr.CalculateChecksum(xsum) == 0xffff
+ if pkt.RXTransportChecksumValidated {
+ return true
+ }
+
+ // On IPv4, UDP checksum is optional, and a zero value means the transmitter
+ // omitted the checksum generation, as per RFC 768:
+ //
+ // An all zero transmitted checksum value means that the transmitter
+ // generated no checksum (for debugging or for higher level protocols that
+ // don't care).
+ //
+ // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
+ //
+ // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
+ // checksum is not optional.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
+ return true
}
- return true
+
+ netHdr := pkt.Network()
+ payloadChecksum := pkt.Data().AsRange().Checksum()
+ return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
}
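
The rewrite above pushes the arithmetic into hdr.IsChecksumValid and documents the RFC language inline. For reference, a self-contained sketch of the underlying ones'-complement check (the RFC 768 computation with RFC 1071 folding; simplified relative to the stack's code):

package main

import "fmt"

// fold adds 16-bit big-endian words into a ones'-complement sum,
// folding carries back into the low 16 bits (RFC 1071).
func fold(b []byte, sum uint32) uint32 {
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8
	}
	for sum > 0xffff {
		sum = sum>>16 + sum&0xffff
	}
	return sum
}

func main() {
	// Pseudo-header: src IP, dst IP, zero byte, protocol (17), UDP length.
	pseudo := []byte{
		192, 168, 1, 2, // source address
		10, 0, 0, 1, // destination address
		0, 17, // zero, protocol
		0, 9, // UDP length: 8-byte header + 1-byte payload
	}
	// UDP header with the checksum field (bytes 6-7) still zero.
	hdr := []byte{0x10, 0x00, 0x00, 0x50, 0, 9, 0, 0} // ports 4096 -> 80, length 9
	payload := []byte{'x'}

	// Sender: the checksum is the complement of the sum over the
	// pseudo-header, header, and payload. (Real UDP transmits 0xffff
	// when this yields zero, since zero means "no checksum" on IPv4.)
	csum := ^uint16(fold(payload, fold(hdr, fold(pseudo, 0))))
	hdr[6], hdr[7] = byte(csum>>8), byte(csum)

	// Receiver: summing everything, checksum field included, must give
	// 0xffff for a valid packet: the check IsChecksumValid performs.
	got := fold(payload, fold(hdr, fold(pseudo, 0)))
	fmt.Printf("valid: %v (sum = %#x)\n", got == 0xffff, got)
}
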
// HandlePacket is called by the stack when new packets arrive to this transport
@@ -1284,7 +1269,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
if !verifyChecksum(hdr, pkt) {
- // Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
e.stats.ReceiveErrors.ChecksumErrors.Increment()
return
@@ -1302,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -1436,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes an endpoint previously frozen with endpoint.freeze(),
+// allowing new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
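
These two hooks replace the old trick of zeroing rcvBufSizeMax during save (removed from endpoint_state.go below). A standalone miniature of the guard, showing how the receive path consults the flag:

package main

import (
	"fmt"
	"sync"
)

// endpoint is a pared-down stand-in; the real one lives in this package.
type endpoint struct {
	mu     sync.Mutex
	frozen bool
	queue  [][]byte
}

func (e *endpoint) freeze() { e.mu.Lock(); e.frozen = true; e.mu.Unlock() }
func (e *endpoint) thaw()   { e.mu.Lock(); e.frozen = false; e.mu.Unlock() }

// deliver mirrors the check added to HandlePacket above: a frozen
// endpoint drops incoming packets instead of mutating state mid-save.
func (e *endpoint) deliver(pkt []byte) bool {
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.frozen {
		return false
	}
	e.queue = append(e.queue, pkt)
	return true
}

func main() {
	var e endpoint
	e.freeze()                                // beforeSave
	fmt.Println(e.deliver([]byte("dropped"))) // false
	e.thaw()                                  // Resume
	fmt.Println(e.deliver([]byte("queued")))  // true
}
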
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 21a6aa460..4aba68b21 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
+
e.mu.Lock()
defer e.mu.Unlock()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 77ca70a04..dc2e3f493 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -34,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -2364,7 +2365,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
ipv4Subnet := ipv4Addr.Subnet()
ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4Gateway := testutil.MustParse4("192.168.1.1")
ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
Address: "\xc0\xa8\x01\x3a",
PrefixLen: 31,
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 7f983a0b3..366f068e3 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -36,8 +36,8 @@ go_test(
tags = [
# Requires docker and runsc to be configured before test runs.
# Also requires the test to be run as root.
- "manual",
"local",
+ "manual",
],
visibility = ["//:sandbox"],
)
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 41fcf4978..06152a444 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error {
select {
case err := <-errChan:
return err
- case <-statusChan:
+ case res := <-statusChan:
+ if res.StatusCode != 0 {
+ var msg string
+ if res.Error != nil {
+ msg = res.Error.Message
+ }
+ return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg)
+ }
return nil
}
}
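
The new Wait surfaces non-zero exit statuses instead of swallowing them. The two-channel select is the usual shape for Docker's container-wait API; a generic sketch with hypothetical types (the real response body carries StatusCode and an optional Error):

package main

import "fmt"

// waitResult is a hypothetical stand-in for the Docker API's wait
// response body (StatusCode plus an optional error message).
type waitResult struct {
	StatusCode int64
	ErrMsg     string
}

// wait mirrors the select above: a transport error wins if it arrives
// first; otherwise a non-zero exit status is turned into an error.
func wait(statusCh <-chan waitResult, errCh <-chan error) error {
	select {
	case err := <-errCh:
		return err
	case res := <-statusCh:
		if res.StatusCode != 0 {
			return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, res.ErrMsg)
		}
		return nil
	}
}

func main() {
	statusCh := make(chan waitResult, 1)
	errCh := make(chan error, 1)
	statusCh <- waitResult{StatusCode: 137, ErrMsg: "OOM killed"}
	fmt.Println(wait(statusCh, errCh))
}
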
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
index 4855a52fc..12fe98b16 100644
--- a/pkg/test/dockerutil/profile.go
+++ b/pkg/test/dockerutil/profile.go
@@ -82,10 +82,15 @@ func (p *profile) createProcess(c *Container) error {
}
// The root directory of this container's runtime.
- root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+ rootDir := fmt.Sprintf("/var/run/docker/runtime-%s/moby", c.runtime)
+ if _, err := os.Stat(rootDir); os.IsNotExist(err) {
+ // In docker v20+, due to https://github.com/moby/moby/issues/42345, the
+ // rootDir always seems to be the path below.
+ rootDir = "/var/run/docker/runtime-runc/moby"
+ }
- // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`.
- args := []string{root, "debug"}
+ // Format is `runsc --root=rootDir debug --profile-*=file --duration=24h containerID`.
+ args := []string{fmt.Sprintf("--root=%s", rootDir), "debug"}
for _, profileArg := range p.Types {
outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg))
args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath))