summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/ioctl.go20
-rw-r--r--pkg/abi/linux/membarrier.go34
-rw-r--r--pkg/abi/linux/netfilter.go21
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go13
-rw-r--r--pkg/abi/linux/seccomp.go19
-rw-r--r--pkg/abi/linux/signalfd.go4
-rw-r--r--pkg/marshal/BUILD4
-rw-r--r--pkg/marshal/primitive/BUILD1
-rw-r--r--pkg/marshal/primitive/primitive.go102
-rw-r--r--pkg/merkletree/merkletree.go256
-rw-r--r--pkg/merkletree/merkletree_test.go209
-rw-r--r--pkg/seccomp/BUILD2
-rw-r--r--pkg/seccomp/seccomp_test.go246
-rw-r--r--pkg/sentry/arch/BUILD1
-rw-r--r--pkg/sentry/arch/arch_aarch64.go2
-rw-r--r--pkg/sentry/arch/registers.proto1
-rw-r--r--pkg/sentry/arch/signal_amd64.go27
-rw-r--r--pkg/sentry/arch/signal_arm64.go30
-rw-r--r--pkg/sentry/arch/stack.go179
-rw-r--r--pkg/sentry/arch/stack_unsafe.go69
-rw-r--r--pkg/sentry/control/proc.go4
-rw-r--r--pkg/sentry/devices/tundev/BUILD1
-rw-r--r--pkg/sentry/devices/tundev/tundev.go14
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/net_tun.go52
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go21
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go17
-rw-r--r--pkg/sentry/fs/proc/task.go44
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go2
-rw-r--r--pkg/sentry/fs/user/path.go1
-rw-r--r--pkg/sentry/fs/user/user.go1
-rw-r--r--pkg/sentry/fsbridge/vfs.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD1
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go91
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go5
-rw-r--r--pkg/sentry/fsimpl/devpts/replica.go3
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go14
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go1
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go1
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go32
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go46
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_new.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_old.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/disklayout.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go12
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_new.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/test_utils.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go1
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go19
-rw-r--r--pkg/sentry/fsimpl/ext/utils.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go66
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go12
-rw-r--r--pkg/sentry/fsimpl/host/host.go6
-rw-r--r--pkg/sentry/fsimpl/host/socket.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD12
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go192
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go235
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go173
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go87
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go7
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go40
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go3
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go5
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go22
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go22
-rw-r--r--pkg/sentry/fsimpl/proc/task.go87
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go67
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go131
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go40
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go56
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go22
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go118
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go12
-rw-r--r--pkg/sentry/fsimpl/signalfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go14
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go5
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go8
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go51
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go37
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs_test.go1
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD21
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go216
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go238
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go491
-rw-r--r--pkg/sentry/hostmm/BUILD3
-rw-r--r--pkg/sentry/hostmm/membarrier.go90
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/kcov.go40
-rw-r--r--pkg/sentry/kernel/kernel.go129
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go12
-rw-r--r--pkg/sentry/kernel/seccomp.go46
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go14
-rw-r--r--pkg/sentry/kernel/task.go4
-rw-r--r--pkg/sentry/kernel/task_context.go6
-rw-r--r--pkg/sentry/kernel/task_signals.go6
-rw-r--r--pkg/sentry/kernel/threads.go7
-rw-r--r--pkg/sentry/kernel/vdso.go19
-rw-r--r--pkg/sentry/loader/loader.go10
-rw-r--r--pkg/sentry/loader/vdso.go6
-rw-r--r--pkg/sentry/memmap/memmap.go4
-rw-r--r--pkg/sentry/mm/mm.go14
-rw-r--r--pkg/sentry/mm/pma.go24
-rw-r--r--pkg/sentry/mm/syscalls.go25
-rw-r--r--pkg/sentry/mm/vma.go10
-rw-r--r--pkg/sentry/platform/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/BUILD12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go31
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go6
-rw-r--r--pkg/sentry/platform/kvm/context.go5
-rw-r--r--pkg/sentry/platform/kvm/filters_amd64.go13
-rw-r--r--pkg/sentry/platform/kvm/filters_arm64.go11
-rw-r--r--pkg/sentry/platform/kvm/kvm.go13
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go21
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go29
-rw-r--r--pkg/sentry/platform/kvm/machine.go52
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go186
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go115
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go26
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go26
-rw-r--r--pkg/sentry/platform/platform.go51
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go1
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go38
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go7
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.s204
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s19
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD5
-rw-r--r--pkg/sentry/platform/ring0/kernel.go22
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go64
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go6
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.s47
-rw-r--r--pkg/sentry/platform/ring0/offsets_amd64.go11
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go4
-rw-r--r--pkg/sentry/platform/ring0/x86.go40
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go7
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go72
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go23
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go23
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go30
-rw-r--r--pkg/sentry/socket/netfilter/targets.go472
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go7
-rw-r--r--pkg/sentry/socket/netstack/netstack.go54
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go7
-rw-r--r--pkg/sentry/socket/unix/BUILD16
-rw-r--r--pkg/sentry/socket/unix/unix.go36
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go20
-rw-r--r--pkg/sentry/strace/strace.go5
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go1
-rw-r--r--pkg/sentry/syscalls/linux/sys_membarrier.go103
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go12
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go6
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/anonfs.go5
-rw-r--r--pkg/sentry/vfs/file_description.go24
-rw-r--r--pkg/sentry/vfs/filesystem_type.go3
-rw-r--r--pkg/sentry/vfs/mount.go11
-rw-r--r--pkg/sentry/vfs/vfs.go10
-rw-r--r--pkg/tcpip/buffer/view.go18
-rw-r--r--pkg/tcpip/checker/checker.go202
-rw-r--r--pkg/tcpip/faketime/faketime.go20
-rw-r--r--pkg/tcpip/header/eth.go16
-rw-r--r--pkg/tcpip/header/eth_test.go47
-rw-r--r--pkg/tcpip/header/icmpv4.go9
-rw-r--r--pkg/tcpip/header/icmpv6.go21
-rw-r--r--pkg/tcpip/header/ipv4.go71
-rw-r--r--pkg/tcpip/header/ipv6.go14
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go113
-rw-r--r--pkg/tcpip/header/ipversion_test.go2
-rw-r--r--pkg/tcpip/header/parse/parse.go2
-rw-r--r--pkg/tcpip/link/pipe/BUILD15
-rw-r--r--pkg/tcpip/link/pipe/pipe.go124
-rw-r--r--pkg/tcpip/link/tun/device.go42
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go107
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD5
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go129
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go252
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go12
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go512
-rw-r--r--pkg/tcpip/network/ipv4/BUILD4
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go150
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go507
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go583
-rw-r--r--pkg/tcpip/network/ipv6/BUILD5
-rw-r--r--pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go (renamed from pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go)2
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go242
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go295
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1093
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go687
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go (renamed from pkg/tcpip/stack/ndp.go)576
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go95
-rw-r--r--pkg/tcpip/network/testutil/BUILD1
-rw-r--r--pkg/tcpip/stack/BUILD10
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go753
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go77
-rw-r--r--pkg/tcpip/stack/conntrack.go46
-rw-r--r--pkg/tcpip/stack/forwarding_test.go (renamed from pkg/tcpip/stack/forwarder_test.go)61
-rw-r--r--pkg/tcpip/stack/iptables.go52
-rw-r--r--pkg/tcpip/stack/iptables_targets.go147
-rw-r--r--pkg/tcpip/stack/iptables_types.go43
-rw-r--r--pkg/tcpip/stack/ndp_test.go815
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go15
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go66
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go21
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go25
-rw-r--r--pkg/tcpip/stack/nic.go1424
-rw-r--r--pkg/tcpip/stack/nic_test.go152
-rw-r--r--pkg/tcpip/stack/packet_buffer.go15
-rw-r--r--pkg/tcpip/stack/pending_packets.go (renamed from pkg/tcpip/stack/forwarder.go)60
-rw-r--r--pkg/tcpip/stack/registration.go305
-rw-r--r--pkg/tcpip/stack/route.go150
-rw-r--r--pkg/tcpip/stack/stack.go379
-rw-r--r--pkg/tcpip/stack/stack_test.go286
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go25
-rw-r--r--pkg/tcpip/tcpip.go11
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go378
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go219
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go64
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go18
-rw-r--r--pkg/tcpip/transport/tcp/connect.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/tcp/rack.go54
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go47
-rw-r--r--pkg/tcpip/transport/tcp/segment.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go62
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go75
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go28
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go12
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go9
-rw-r--r--pkg/tcpip/transport/udp/protocol.go13
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go97
-rw-r--r--pkg/test/testutil/testutil.go2
-rw-r--r--pkg/usermem/usermem.go46
-rw-r--r--pkg/usermem/usermem_test.go18
276 files changed, 13311 insertions, 5672 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index cdcaa8c73..4a26e28de 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -38,6 +38,7 @@ go_library(
"ipc.go",
"limits.go",
"linux.go",
+ "membarrier.go",
"mm.go",
"netdevice.go",
"netfilter.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index dc9ac7e7c..7df02dd6d 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,9 +121,27 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
- FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_MEASURE_VERITY = 3221513862
)
+// DigestMetadata is a helper struct for VerityDigest.
+//
+// +marshal
+type DigestMetadata struct {
+ DigestAlgorithm uint16
+ DigestSize uint16
+}
+
+// SizeOfDigestMetadata is the size of struct DigestMetadata.
+const SizeOfDigestMetadata = 4
+
+// VerityDigest is struct from uapi/linux/fsverity.h.
+type VerityDigest struct {
+ Metadata DigestMetadata
+ Digest []byte
+}
+
// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
func IOC(dir, typ, nr, size uint32) uint32 {
return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
diff --git a/pkg/abi/linux/membarrier.go b/pkg/abi/linux/membarrier.go
new file mode 100644
index 000000000..4f6021a1d
--- /dev/null
+++ b/pkg/abi/linux/membarrier.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// membarrier(2) commands, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_QUERY = 0
+ MEMBARRIER_CMD_GLOBAL = (1 << 0)
+ MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1)
+ MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 5)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 6)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ = (1 << 7)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ = (1 << 8)
+)
+
+// membarrier(2) flags, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_FLAG_CPU = (1 << 0)
+)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 1c5b34711..b521144d9 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -265,6 +265,18 @@ type KernelXTEntryMatch struct {
Data []byte
}
+// XTGetRevision corresponds to xt_get_revision in
+// include/uapi/linux/netfilter/x_tables.h
+//
+// +marshal
+type XTGetRevision struct {
+ Name ExtensionName
+ Revision uint8
+}
+
+// SizeOfXTGetRevision is the size of an XTGetRevision.
+const SizeOfXTGetRevision = 30
+
// XTEntryTarget holds a target for a rule. For example, it can specify that
// packets matching the rule should DROP, ACCEPT, or use an extension target.
// iptables-extension(8) has a list of possible targets.
@@ -285,6 +297,13 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
+// KernelXTEntryTarget is identical to XTEntryTarget, but contains a
+// variable-length Data field.
+type KernelXTEntryTarget struct {
+ XTEntryTarget
+ Data []byte
+}
+
// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
@@ -510,6 +529,8 @@ type IPTReplace struct {
const SizeOfIPTReplace = 96
// ExtensionName holds the name of a netfilter extension.
+//
+// +marshal
type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
// String implements fmt.Stringer.
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index a137940b6..6d31eb5e3 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -321,3 +321,16 @@ const (
// Enable all flags.
IP6T_INV_MASK = 0x7F
)
+
+// NFNATRange corresponds to struct nf_nat_range in
+// include/uapi/linux/netfilter/nf_nat.h.
+type NFNATRange struct {
+ Flags uint32
+ MinAddr Inet6Addr
+ MaxAddr Inet6Addr
+ MinProto uint16 // Network byte order.
+ MaxProto uint16 // Network byte order.
+}
+
+// SizeOfNFNATRange is the size of NFNATRange.
+const SizeOfNFNATRange = 40
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index b07cafe12..5be3f10f9 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -83,3 +83,22 @@ type SockFprog struct {
pad [6]byte
Filter *BPFInstruction
}
+
+// SeccompData is equivalent to struct seccomp_data, which contains the data
+// passed to seccomp-bpf filters.
+//
+// +marshal
+type SeccompData struct {
+ // Nr is the system call number.
+ Nr int32
+
+ // Arch is an AUDIT_ARCH_* value indicating the system call convention.
+ Arch uint32
+
+ // InstructionPointer is the value of the instruction pointer at the time
+ // of the system call.
+ InstructionPointer uint64
+
+ // Args contains the first 6 system call arguments.
+ Args [6]uint64
+}
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
index 85fad9956..468c6a387 100644
--- a/pkg/abi/linux/signalfd.go
+++ b/pkg/abi/linux/signalfd.go
@@ -23,6 +23,8 @@ const (
)
// SignalfdSiginfo is the siginfo encoding for signalfds.
+//
+// +marshal
type SignalfdSiginfo struct {
Signo uint32
Errno int32
@@ -41,5 +43,5 @@ type SignalfdSiginfo struct {
STime uint64
Addr uint64
AddrLSB uint16
- _ [48]uint8
+ _ [48]uint8 `marshal:"unaligned"`
}
diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD
index 4aec98218..aac0161fa 100644
--- a/pkg/marshal/BUILD
+++ b/pkg/marshal/BUILD
@@ -11,7 +11,5 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = [
- "//pkg/usermem",
- ],
+ deps = ["//pkg/usermem"],
)
diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
index 06741e6d1..d77a11c79 100644
--- a/pkg/marshal/primitive/BUILD
+++ b/pkg/marshal/primitive/BUILD
@@ -12,6 +12,7 @@ go_library(
"//:sandbox",
],
deps = [
+ "//pkg/context",
"//pkg/marshal",
"//pkg/usermem",
],
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index dfdae5d60..4b342de6b 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -19,6 +19,7 @@ package primitive
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -126,6 +127,46 @@ var _ marshal.Marshallable = (*ByteSlice)(nil)
// Below, we define some convenience functions for marshalling primitive types
// using the newtypes above, without requiring superfluous casts.
+// 8-bit integers
+
+// CopyInt8In is a convenient wrapper for copying in an int8 from the task's
+// memory.
+func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) {
+ var buf Int8
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int8(buf)
+ return n, nil
+}
+
+// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's
+// memory.
+func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) {
+ srcP := Int8(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's
+// memory.
+func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) {
+ var buf Uint8
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint8(buf)
+ return n, nil
+}
+
+// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's
+// memory.
+func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) {
+ srcP := Uint8(src)
+ return srcP.CopyOut(cc, addr)
+}
+
// 16-bit integers
// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
@@ -245,3 +286,64 @@ func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int,
srcP := Uint64(src)
return srcP.CopyOut(cc, addr)
}
+
+// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the
+// task's memory.
+func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) {
+ var buf ByteSlice
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = []byte(buf)
+ return n, nil
+}
+
+// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the
+// task's memory.
+func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) {
+ srcP := ByteSlice(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyStringIn is a convenient wrapper for copying in a string from the
+// task's memory.
+func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) {
+ var buf ByteSlice
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = string(buf)
+ return n, nil
+}
+
+// CopyStringOut is a convenient wrapper for copying out a string to the task's
+// memory.
+func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) {
+ srcP := ByteSlice(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// IOCopyContext wraps an object implementing usermem.IO to implement
+// marshal.CopyContext.
+type IOCopyContext struct {
+ Ctx context.Context
+ IO usermem.IO
+ Opts usermem.IOOpts
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+ return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+ return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
+}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 4b4f9bd52..d8227b8bd 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -41,7 +41,7 @@ type Layout struct {
blockSize int64
// digestSize is the size of a generated hash.
digestSize int64
- // levelOffset contains the offset of the begnning of each level in
+ // levelOffset contains the offset of the beginning of each level in
// bytes. The number of levels in the tree is the length of the slice.
// The leaf nodes (level 0) contain hashes of blocks of the input data.
// Each level N contains hashes of the blocks in level N-1. The highest
@@ -123,48 +123,73 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
return layout.levelOffset[level] + index*layout.blockSize
}
-// Generate constructs a Merkle tree for the contents of data. The output is
-// written to treeWriter. The treeReader should be able to read the tree after
-// it has been written. That is, treeWriter and treeReader should point to the
-// same underlying data but have separate cursors.
-// Generate will modify the cursor for data, but always restores it to its
-// original position upon exit. The cursor for tree is modified and not
-// restored.
-func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) {
- layout := InitLayout(dataSize, dataAndTreeInSameFile)
+// VerityDescriptor is a struct that is serialized and hashed to get a file's
+// root hash, which contains the root hash of the raw content and the file's
+// meatadata.
+type VerityDescriptor struct {
+ Name string
+ Mode uint32
+ UID uint32
+ GID uint32
+ RootHash []byte
+}
- numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize
+func (d *VerityDescriptor) String() string {
+ return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+}
+
+// verify generates a hash from d, and compares it with expected.
+func (d *VerityDescriptor) verify(expected []byte) error {
+ h := sha256.Sum256([]byte(d.String()))
+ if !bytes.Equal(h[:], expected) {
+ return fmt.Errorf("unexpected root hash")
+ }
+ return nil
+}
+
+// GenerateParams contains the parameters used to generate a Merkle tree.
+type GenerateParams struct {
+ // File is a reader of the file to be hashed.
+ File io.ReaderAt
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // TreeReader is a reader for the Merkle tree.
+ TreeReader io.ReaderAt
+ // TreeWriter is a writer for the Merkle tree.
+ TreeWriter io.Writer
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
+// Generate constructs a Merkle tree for the contents of params.File. The
+// output is written to params.TreeWriter.
+//
+// Generate returns a hash of a VerityDescriptor, which contains the file
+// metadata and the hash from file content.
+func Generate(params *GenerateParams) ([]byte, error) {
+ layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
+
+ numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
// If the data is in the same file as the tree, zero pad the last data
// block.
- bytesInLastBlock := dataSize % layout.blockSize
- if dataAndTreeInSameFile && bytesInLastBlock != 0 {
+ bytesInLastBlock := params.Size % layout.blockSize
+ if params.DataAndTreeInSameFile && bytesInLastBlock != 0 {
zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock)
- if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF {
- return nil, err
- }
- if _, err := treeWriter.Write(zeroBuf); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf); err != nil {
return nil, err
}
}
- // Store the current offset, so we can set it back once verification
- // finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
- if err != nil {
- return nil, err
- }
- defer data.Seek(origOffset, io.SeekStart)
-
- // Read from the beginning of both data and treeReader.
- if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF {
- return nil, err
- }
-
- if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF {
- return nil, err
- }
-
var root []byte
for level := 0; level < layout.numLevels(); level++ {
for i := int64(0); i < numBlocks; i++ {
@@ -176,11 +201,11 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
if level == 0 {
// Read data block from the target file since level 0 includes hashes
// of blocks in the input data.
- n, err = data.Read(buf)
+ n, err = params.File.ReadAt(buf, i*layout.blockSize)
} else {
// Read data block from the tree file since levels higher than 0 are
// hashing the lower level hashes.
- n, err = treeReader.Read(buf)
+ n, err = params.TreeReader.ReadAt(buf, layout.blockOffset(level-1, i))
}
// err is populated as long as the bytes read is smaller than the buffer
@@ -200,7 +225,7 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
}
// Write the generated hash to the end of the tree file.
- if _, err = treeWriter.Write(digest[:]); err != nil {
+ if _, err = params.TreeWriter.Write(digest[:]); err != nil {
return nil, err
}
}
@@ -208,46 +233,95 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
// remaining of the last block. But no need to do so for root.
if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 {
zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize)
- if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf[:]); err != nil {
return nil, err
}
}
numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock()
}
- return root, nil
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ ret := sha256.Sum256([]byte(descriptor.String()))
+ return ret[:], nil
+}
+
+// VerifyParams contains the params used to verify a portion of a file against
+// a Merkle tree.
+type VerifyParams struct {
+ // Out will be filled with verified data.
+ Out io.Writer
+ // File is a handler on the file to be verified.
+ File io.ReaderAt
+ // tree is a handler on the Merkle tree used to verify file.
+ Tree io.ReaderAt
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // ReadOffset is the offset of the data range to be verified.
+ ReadOffset int64
+ // ReadSize is the size of the data range to be verified.
+ ReadSize int64
+ // Expected is a trusted hash for the file. It is compared with the
+ // calculated root hash to verify the content.
+ Expected []byte
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
+// verifyMetadata verifies the metadata by hashing a descriptor that contains
+// the metadata and compare the generated hash with expected.
+//
+// For verifyMetadata, params.data is not needed. It only accesses params.tree
+// for the raw root hash.
+func verifyMetadata(params *VerifyParams, layout *Layout) error {
+ root := make([]byte, layout.digestSize)
+ if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ return descriptor.verify(params.Expected)
}
// Verify verifies the content read from data with offset. The content is
// verified against tree. If content spans across multiple blocks, each block is
// verified. Verification fails if the hash of the data does not match the tree
-// at any level, or if the final root hash does not match expectedRoot.
-// Once the data is verified, it will be written using w.
-// Verify will modify the cursor for data, but always restores it to its
-// original position upon exit. The cursor for tree is modified and not
-// restored.
-func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) {
- if readSize <= 0 {
- return 0, fmt.Errorf("Unexpected read size: %d", readSize)
+// at any level, or if the final root hash does not match expected.
+// Once the data is verified, it will be written using params.Out.
+//
+// Verify checks for both target file content and metadata. If readSize is 0,
+// only metadata is checked.
+func Verify(params *VerifyParams) (int64, error) {
+ if params.ReadSize < 0 {
+ return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
+ }
+ layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ if params.ReadSize == 0 {
+ return 0, verifyMetadata(params, &layout)
}
- layout := InitLayout(int64(dataSize), dataAndTreeInSameFile)
// Calculate the index of blocks that includes the target range in input
// data.
- firstDataBlock := readOffset / layout.blockSize
- lastDataBlock := (readOffset + readSize - 1) / layout.blockSize
-
- // Store the current offset, so we can set it back once verification
- // finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
- if err != nil {
- return 0, fmt.Errorf("Find current data offset failed: %v", err)
- }
- defer data.Seek(origOffset, io.SeekStart)
-
- // Move to the first block that contains target data.
- if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
- return 0, fmt.Errorf("Seek to datablock start failed: %v", err)
- }
+ firstDataBlock := params.ReadOffset / layout.blockSize
+ lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize
buf := make([]byte, layout.blockSize)
var readErr error
@@ -255,7 +329,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
for i := firstDataBlock; i <= lastDataBlock; i++ {
// Read a block that includes all or part of target range in
// input data.
- bytesRead, err := data.Read(buf)
+ bytesRead, err := params.File.ReadAt(buf, i*layout.blockSize)
readErr = err
// If at the end of input data and all previous blocks are
// verified, return the verified input data and EOF.
@@ -263,7 +337,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
break
}
if readErr != nil && readErr != io.EOF {
- return 0, fmt.Errorf("Read from data failed: %v", err)
+ return 0, fmt.Errorf("read from data failed: %w", err)
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
@@ -274,22 +348,29 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
buf[j] = 0
}
}
- if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ }
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil {
return 0, err
}
+
// startOff is the beginning of the read range within the
// current data block. Note that for all blocks other than the
// first, startOff should be 0.
startOff := int64(0)
if i == firstDataBlock {
- startOff = readOffset % layout.blockSize
+ startOff = params.ReadOffset % layout.blockSize
}
// endOff is the end of the read range within the current data
// block. Note that for all blocks other than the last, endOff
// should be the block size.
endOff := layout.blockSize
if i == lastDataBlock {
- endOff = (readOffset+readSize-1)%layout.blockSize + 1
+ endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1
}
// If the provided size exceeds the end of input data, we should
// only copy the parts in buf that's part of input data.
@@ -299,7 +380,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
if endOff > int64(bytesRead) {
endOff = int64(bytesRead)
}
- n, err := w.Write(buf[startOff:endOff])
+ n, err := params.Out.Write(buf[startOff:endOff])
if err != nil {
return total, err
}
@@ -313,9 +394,8 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
// original data. The block is verified through each level of the tree. It
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
-// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to
-// maintain the cursor if intended.
-func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error {
+// expected.
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -332,41 +412,27 @@ func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex
// Read a block in previous level that contains the
// hash we just generated, and generate a next level
// hash from it.
- if _, err := tree.Seek(layout.blockOffset(level-1, blockIndex), io.SeekStart); err != nil {
- return err
- }
- if _, err := tree.Read(treeBlock); err != nil {
+ if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil {
return err
}
digestArray := sha256.Sum256(treeBlock)
digest = digestArray[:]
}
- // Move to stored hash for the current block, read the digest
- // and store in expectedDigest.
- if _, err := tree.Seek(layout.digestOffset(level, blockIndex), io.SeekStart); err != nil {
- return err
- }
- if _, err := tree.Read(expectedDigest); err != nil {
+ // Read the digest for the current block and store in
+ // expectedDigest.
+ if _, err := tree.ReadAt(expectedDigest, layout.digestOffset(level, blockIndex)); err != nil {
return err
}
if !bytes.Equal(digest, expectedDigest) {
- return fmt.Errorf("Verification failed")
- }
-
- // If this is the root layer, no need to generate next level
- // hash.
- if level == layout.rootLevel() {
- break
+ return fmt.Errorf("verification failed")
}
blockIndex = blockIndex / layout.hashesPerBlock()
}
- // Verification for the tree succeeded. Now compare the root hash in the
- // tree with expectedRoot.
- if !bytes.Equal(digest[:], expectedRoot) {
- return fmt.Errorf("Verification failed")
- }
- return nil
+ // Verification for the tree succeeded. Now hash the descriptor with
+ // the root hash and compare it with expected.
+ descriptor.RootHash = digest
+ return descriptor.verify(expected)
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index daaca759a..e1350ebda 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -84,6 +84,13 @@ func TestLayout(t *testing.T) {
}
}
+const (
+ defaultName = "merkle_test"
+ defaultMode = 0644
+ defaultUID = 0
+ defaultGID = 0
+)
+
// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike
// bytes.Buffer, it keeps the whole buffer during read so that it can be reused.
type bytesReadWriter struct {
@@ -99,58 +106,36 @@ func (brw *bytesReadWriter) Write(p []byte) (int, error) {
return len(p), nil
}
-func (brw *bytesReadWriter) Read(p []byte) (int, error) {
- if brw.readPos >= len(brw.bytes) {
- return 0, io.EOF
- }
- bytesRead := copy(p, brw.bytes[brw.readPos:])
- brw.readPos += bytesRead
- if bytesRead < len(p) {
+func (brw *bytesReadWriter) ReadAt(p []byte, off int64) (int, error) {
+ bytesRead := copy(p, brw.bytes[off:])
+ if bytesRead == 0 {
return bytesRead, io.EOF
}
return bytesRead, nil
}
-func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) {
- off := offset
- if whence == io.SeekCurrent {
- off += int64(brw.readPos)
- }
- if whence == io.SeekEnd {
- off += int64(len(brw.bytes))
- }
- if off < 0 {
- panic("seek with negative offset")
- }
- if off >= int64(len(brw.bytes)) {
- return 0, io.EOF
- }
- brw.readPos = int(off)
- return off, nil
-}
-
func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
data []byte
- expectedRoot []byte
+ expectedHash []byte
}{
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
},
{
data: []byte{'a'},
- expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90},
+ expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
},
}
@@ -158,22 +143,31 @@ func TestGenerate(t *testing.T) {
t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ params := GenerateParams{
+ Size: int64(len(tc.data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(tc.data)
- root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ params.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ params.File = &bytesReadWriter{
bytes: tc.data,
- }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ hash, err := Generate(&params)
if err != nil {
t.Fatalf("Got err: %v, want nil", err)
}
- if !bytes.Equal(root, tc.expectedRoot) {
- t.Errorf("Got root: %v, want %v", root, tc.expectedRoot)
+ if !bytes.Equal(hash, tc.expectedHash) {
+ t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash)
}
}
})
@@ -194,6 +188,10 @@ func TestVerify(t *testing.T) {
// modified byte falls in verification range, Verify should
// fail, otherwise Verify should still succeed.
modifyByte int64
+ modifyName bool
+ modifyMode bool
+ modifyUID bool
+ modifyGID bool
shouldSucceed bool
}{
// Verify range start outside the data range should fail.
@@ -222,12 +220,48 @@ func TestVerify(t *testing.T) {
modifyByte: 0,
shouldSucceed: false,
},
- // Invalid verify range (0 size) should fail.
+ // 0 verify size should only verify metadata.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ shouldSucceed: true,
+ },
+ // Modified name should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyName: true,
+ shouldSucceed: false,
+ },
+ // Modified mode should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyMode: true,
+ shouldSucceed: false,
+ },
+ // Modified UID should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyUID: true,
+ shouldSucceed: false,
+ },
+ // Modified GID should fail verification.
{
dataSize: usermem.PageSize,
verifyStart: 0,
verifySize: 0,
modifyByte: 0,
+ modifyGID: true,
shouldSucceed: false,
},
// The test cases below use a block-aligned verify range.
@@ -316,16 +350,25 @@ func TestVerify(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */)
+ }
}
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -333,8 +376,34 @@ func TestVerify(t *testing.T) {
// Flip a bit in data and checks Verify results.
var buf bytes.Buffer
data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
+ }
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
+ }
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
+ }
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
if tc.shouldSucceed {
- n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed when expected to succeed: %v", err)
}
@@ -348,7 +417,7 @@ func TestVerify(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
} else {
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&verifyParams); err == nil {
t.Errorf("Verification succeeded when expected to fail")
}
}
@@ -368,16 +437,26 @@ func TestVerifyRandom(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -387,9 +466,24 @@ func TestVerifyRandom(t *testing.T) {
size := rand.Int63n(dataSize) + 1
var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: start,
+ ReadSize: size,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
// Checks that the random portion of data from the original data is
// verified successfully.
- n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed for correct data: %v", err)
}
@@ -406,13 +500,22 @@ func TestVerifyRandom(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
+ // Verify that modified metadata should fail verification.
buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
+
// Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
randBytePos := rand.Int63n(size)
data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
- t.Errorf("Verification succeeded for modified data")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
}
}
}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index bdef7762c..e828894b0 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -49,7 +49,7 @@ go_test(
library = ":seccomp",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bpf",
+ "//pkg/usermem",
],
)
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 23f30678d..e1444d18b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -28,17 +28,10 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-type seccompData struct {
- nr uint32
- arch uint32
- instructionPointer uint64
- args [6]uint64
-}
-
// newVictim makes a victim binary.
func newVictim() (string, error) {
f, err := ioutil.TempFile("", "victim")
@@ -58,9 +51,14 @@ func newVictim() (string, error) {
return path, nil
}
-// asInput converts a seccompData to a bpf.Input.
-func (d *seccompData) asInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+// dataAsInput converts a linux.SeccompData to a bpf.Input.
+func dataAsInput(d *linux.SeccompData) bpf.Input {
+ buf := make([]byte, d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ Order: usermem.ByteOrder,
+ }
}
func TestBasic(t *testing.T) {
@@ -69,7 +67,7 @@ func TestBasic(t *testing.T) {
desc string
// data is the input data.
- data seccompData
+ data linux.SeccompData
// want is the expected return value of the BPF program.
want linux.BPFAction
@@ -95,12 +93,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "syscall allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "syscall disallowed",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -131,22 +129,22 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1a)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (1b)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "syscall 1 matched 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "no match",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -168,42 +166,42 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (3)",
- data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (5)",
- data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed (0)",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (2)",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (4)",
- data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (6)",
- data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (100)",
- data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -223,7 +221,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arch (123)",
- data: seccompData{nr: 1, arch: 123},
+ data: linux.SeccompData{Nr: 1, Arch: 123},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -243,7 +241,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "action trap",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -268,12 +266,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -300,12 +298,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "match first rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "match 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -331,28 +329,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "argument allowed (all match)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "argument disallowed (one mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "argument disallowed (multiple mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -379,28 +377,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (one equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (all equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -429,27 +427,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -474,27 +472,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -522,27 +520,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -567,32 +565,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (both greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -620,27 +618,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -665,32 +663,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -718,27 +716,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -764,32 +762,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -816,51 +814,51 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (low order mandatory bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000001
- args: [6]uint64{0x1},
+ Args: [6]uint64{0x1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (low order optional bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000101
- args: [6]uint64{0x5},
+ Args: [6]uint64{0x5},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (lowest order bit not set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000010
- args: [6]uint64{0x2},
+ Args: [6]uint64{0x2},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second lowest order bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000011
- args: [6]uint64{0x3},
+ Args: [6]uint64{0x3},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (8th bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000001 00000000
- args: [6]uint64{0x100},
+ Args: [6]uint64{0x100},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -885,12 +883,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -906,7 +904,7 @@ func TestBasic(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for _, spec := range test.specs {
- got, err := bpf.Exec(p, spec.data.asInput())
+ got, err := bpf.Exec(p, dataAsInput(&spec.data))
if err != nil {
t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err)
}
@@ -947,8 +945,8 @@ func TestRandom(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for i := uint32(0); i < 200; i++ {
- data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH}
- got, err := bpf.Exec(p, data.asInput())
+ data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH}
+ got, err := bpf.Exec(p, dataAsInput(&data))
if err != nil {
t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i)
continue
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 99e2b3389..4af4d6e84 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -22,6 +22,7 @@ go_library(
"signal_info.go",
"signal_stack.go",
"stack.go",
+ "stack_unsafe.go",
"syscalls_amd64.go",
"syscalls_arm64.go",
],
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 0f433ee79..fd73751e7 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -154,6 +154,7 @@ func (s State) Proto() *rpb.Registers {
Sp: s.Regs.Sp,
Pc: s.Regs.Pc,
Pstate: s.Regs.Pstate,
+ Tls: s.Regs.TPIDR_EL0,
}
return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
}
@@ -232,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
"Sp": uintptr(s.Regs.Sp),
"Pc": uintptr(s.Regs.Pc),
"Pstate": uintptr(s.Regs.Pstate),
+ "Tls": uintptr(s.Regs.TPIDR_EL0),
}, nil
}
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 60c027aab..2727ba08a 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -83,6 +83,7 @@ message ARM64Registers {
uint64 sp = 32;
uint64 pc = 33;
uint64 pstate = 34;
+ uint64 tls = 35;
}
message Registers {
oneof arch {
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 6fb756f0e..72e07a988 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -17,17 +17,19 @@
package arch
import (
- "encoding/binary"
"math"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
)
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
+//
+// +marshal
type SignalContext64 struct {
R8 uint64
R9 uint64
@@ -68,6 +70,8 @@ const (
)
// UContext64 is equivalent to ucontext_t on 64-bit x86.
+//
+// +marshal
type UContext64 struct {
Flags uint64
Link uint64
@@ -172,12 +176,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// "... the value (%rsp+8) is always a multiple of 16 (...) when
// control is transferred to the function entry point." - AMD64 ABI
- ucSize := binary.Size(uc)
- if ucSize < 0 {
- // This can only happen if we've screwed up the definition of
- // UContext64.
- panic("can't get size of UContext64")
- }
+ ucSize := uc.SizeBytes()
// st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
frameSize := int(st.Arch.Width()) + ucSize + 128
frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
@@ -195,18 +194,18 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
info.FixSignalCodeForUser()
// Set up the stack frame.
- infoAddr, err := st.Push(info)
- if err != nil {
+ if _, err := info.CopyOut(st, StackBottomMagic); err != nil {
return err
}
- ucAddr, err := st.Push(uc)
- if err != nil {
+ infoAddr := st.Bottom
+ if _, err := uc.CopyOut(st, StackBottomMagic); err != nil {
return err
}
+ ucAddr := st.Bottom
if act.HasRestorer() {
// Push the restorer return address.
// Note that this doesn't need to be popped.
- if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil {
+ if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil {
return err
}
} else {
@@ -240,11 +239,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
- if _, err := st.Pop(&uc); err != nil {
+ if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
var info SignalInfo
- if _, err := st.Pop(&info); err != nil {
+ if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 642c79dda..7fde5d34e 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package arch
import (
- "encoding/binary"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +26,8 @@ import (
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
+//
+// +marshal
type SignalContext64 struct {
FaultAddr uint64
Regs [31]uint64
@@ -36,6 +39,7 @@ type SignalContext64 struct {
Reserved [3568]uint8
}
+// +marshal
type aarch64Ctx struct {
Magic uint32
Size uint32
@@ -43,6 +47,8 @@ type aarch64Ctx struct {
// FpsimdContext is equivalent to struct fpsimd_context on arm64
// (arch/arm64/include/uapi/asm/sigcontext.h).
+//
+// +marshal
type FpsimdContext struct {
Head aarch64Ctx
Fpsr uint32
@@ -51,13 +57,15 @@ type FpsimdContext struct {
}
// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
+//
+// +marshal
type UContext64 struct {
Flags uint64
Link uint64
Stack SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
- _pad [(1024 - 64) / 8]byte
+ _pad [120]byte // (1024 - 64) / 8 = 120
// sigcontext must be aligned to 16-byte
_pad2 [8]byte
// last for future expansion
@@ -94,11 +102,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
},
Sigset: sigset,
}
-
- ucSize := binary.Size(uc)
- if ucSize < 0 {
- panic("can't get size of UContext64")
- }
+ ucSize := uc.SizeBytes()
// frameSize = ucSize + sizeof(siginfo).
// sizeof(siginfo) == 128.
@@ -119,14 +123,14 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
info.FixSignalCodeForUser()
// Set up the stack frame.
- infoAddr, err := st.Push(info)
- if err != nil {
+ if _, err := info.CopyOut(st, StackBottomMagic); err != nil {
return err
}
- ucAddr, err := st.Push(uc)
- if err != nil {
+ infoAddr := st.Bottom
+ if _, err := uc.CopyOut(st, StackBottomMagic); err != nil {
return err
}
+ ucAddr := st.Bottom
// Set up registers.
c.Regs.Sp = uint64(st.Bottom)
@@ -147,11 +151,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
- if _, err := st.Pop(&uc); err != nil {
+ if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
var info SignalInfo
- if _, err := st.Pop(&info); err != nil {
+ if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 1108fa0bd..5f06c751d 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -15,14 +15,16 @@
package arch
import (
- "encoding/binary"
- "fmt"
-
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
)
-// Stack is a simple wrapper around a usermem.IO and an address.
+// Stack is a simple wrapper around a usermem.IO and an address. Stack
+// implements marshal.CopyContext, and marshallable values can be pushed or
+// popped from the stack through the marshal.Marshallable interface.
+//
+// Stack is not thread-safe.
type Stack struct {
// Our arch info.
// We use this for automatic Native conversion of usermem.Addrs during
@@ -34,105 +36,60 @@ type Stack struct {
// Our current stack bottom.
Bottom usermem.Addr
-}
-// Push pushes the given values on to the stack.
-//
-// (This method supports Addrs and treats them as native types.)
-func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
- for _, v := range vals {
-
- // We convert some types to well-known serializable quanities.
- var norm interface{}
-
- // For array types, we will automatically add an appropriate
- // terminal value. This is done simply to make the interface
- // easier to use.
- var term interface{}
-
- switch v.(type) {
- case string:
- norm = []byte(v.(string))
- term = byte(0)
- case []int8, []uint8:
- norm = v
- term = byte(0)
- case []int16, []uint16:
- norm = v
- term = uint16(0)
- case []int32, []uint32:
- norm = v
- term = uint32(0)
- case []int64, []uint64:
- norm = v
- term = uint64(0)
- case []usermem.Addr:
- // Special case: simply push recursively.
- _, err := s.Push(s.Arch.Native(uintptr(0)))
- if err != nil {
- return 0, err
- }
- varr := v.([]usermem.Addr)
- for i := len(varr) - 1; i >= 0; i-- {
- _, err := s.Push(varr[i])
- if err != nil {
- return 0, err
- }
- }
- continue
- case usermem.Addr:
- norm = s.Arch.Native(uintptr(v.(usermem.Addr)))
- default:
- norm = v
- }
+ // Scratch buffer used for marshalling to avoid having to repeatedly
+ // allocate scratch memory.
+ scratchBuf []byte
+}
- if term != nil {
- _, err := s.Push(term)
- if err != nil {
- return 0, err
- }
- }
+// scratchBufLen is the default length of Stack.scratchBuf. The
+// largest structs the stack regularly serializes are arch.SignalInfo
+// and arch.UContext64. We'll set the default size as the larger of
+// the two, arch.UContext64.
+var scratchBufLen = (*UContext64)(nil).SizeBytes()
- c := binary.Size(norm)
- if c < 0 {
- return 0, fmt.Errorf("bad binary.Size for %T", v)
- }
- n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
- if err != nil || c != n {
- return 0, err
- }
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (s *Stack) CopyScratchBuffer(size int) []byte {
+ if len(s.scratchBuf) < size {
+ s.scratchBuf = make([]byte, size)
+ }
+ return s.scratchBuf[:size]
+}
+// StackBottomMagic is the special address callers must past to all stack
+// marshalling operations to cause the src/dst address to be computed based on
+// the current end of the stack.
+const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1)
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes
+// computes an appropriate address based on the current end of the
+// stack. Callers use the sentinel address StackBottomMagic to marshal methods
+// to indicate this.
+func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) {
+ if sentinel != StackBottomMagic {
+ panic("Attempted to copy out to stack with absolute address")
+ }
+ c := len(b)
+ n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{})
+ if err == nil && n == c {
s.Bottom -= usermem.Addr(n)
}
-
- return s.Bottom, nil
+ return n, err
}
-// Pop pops the given values off the stack.
-//
-// (This method supports Addrs and treats them as native types.)
-func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
- for _, v := range vals {
-
- vaddr, isVaddr := v.(*usermem.Addr)
-
- var n int
- var err error
- if isVaddr {
- value := s.Arch.Native(uintptr(0))
- n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
- *vaddr = usermem.Addr(s.Arch.Value(value))
- } else {
- n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
- }
- if err != nil {
- return 0, err
- }
-
+// CopyInBytes implements marshal.CopyContext.CopyInBytes. CopyInBytes computes
+// an appropriate address based on the current end of the stack. Callers must
+// use the sentinel address StackBottomMagic to marshal methods to indicate
+// this.
+func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) {
+ if sentinel != StackBottomMagic {
+ panic("Attempted to copy in from stack with absolute address")
+ }
+ n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{})
+ if err == nil {
s.Bottom += usermem.Addr(n)
}
-
- return s.Bottom, nil
+ return n, err
}
// Align aligns the stack to the given offset.
@@ -142,6 +99,22 @@ func (s *Stack) Align(offset int) {
}
}
+// PushNullTerminatedByteSlice writes bs to the stack, followed by an extra null
+// byte at the end. On error, the contents of the stack and the bottom cursor
+// are undefined.
+func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) {
+ // Note: Stack grows up, so write the terminal null byte first.
+ nNull, err := primitive.CopyUint8Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ n, err := primitive.CopyByteSliceOut(s, StackBottomMagic, bs)
+ if err != nil {
+ return 0, err
+ }
+ return n + nNull, nil
+}
+
// StackLayout describes the location of the arguments and environment on the
// stack.
type StackLayout struct {
@@ -177,11 +150,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
l.EnvvEnd = s.Bottom
envAddrs := make([]usermem.Addr, len(env))
for i := len(env) - 1; i >= 0; i-- {
- addr, err := s.Push(env[i])
- if err != nil {
+ if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil {
return StackLayout{}, err
}
- envAddrs[i] = addr
+ envAddrs[i] = s.Bottom
}
l.EnvvStart = s.Bottom
@@ -189,11 +161,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
l.ArgvEnd = s.Bottom
argAddrs := make([]usermem.Addr, len(args))
for i := len(args) - 1; i >= 0; i-- {
- addr, err := s.Push(args[i])
- if err != nil {
+ if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil {
return StackLayout{}, err
}
- argAddrs[i] = addr
+ argAddrs[i] = s.Bottom
}
l.ArgvStart = s.Bottom
@@ -222,26 +193,26 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
auxv = append(auxv, usermem.Addr(a.Key), a.Value)
}
auxv = append(auxv, usermem.Addr(0))
- _, err := s.Push(auxv)
+ _, err := s.pushAddrSliceAndTerminator(auxv)
if err != nil {
return StackLayout{}, err
}
// Push environment.
- _, err = s.Push(envAddrs)
+ _, err = s.pushAddrSliceAndTerminator(envAddrs)
if err != nil {
return StackLayout{}, err
}
// Push args.
- _, err = s.Push(argAddrs)
+ _, err = s.pushAddrSliceAndTerminator(argAddrs)
if err != nil {
return StackLayout{}, err
}
// Push arg count.
- _, err = s.Push(usermem.Addr(len(args)))
- if err != nil {
+ lenP := s.Arch.Native(uintptr(len(args)))
+ if _, err = lenP.CopyOut(s, StackBottomMagic); err != nil {
return StackLayout{}, err
}
diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go
new file mode 100644
index 000000000..a90d297ee
--- /dev/null
+++ b/pkg/sentry/arch/stack_unsafe.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and
+// also pushes an extra null address element at the end of the slice.
+//
+// Internally, we unsafely transmute the slice type from the arch-dependent
+// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to
+// go-marshal.
+//
+// On error, the contents of the stack and the bottom cursor are undefined.
+func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) {
+ // Note: Stack grows upwards, so push the terminator first.
+ srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src))
+ switch s.Arch.Width() {
+ case 8:
+ nNull, err := primitive.CopyUint64Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ var dst []uint64
+ dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
+ dstHdr.Data = srcHdr.Data
+ dstHdr.Len = srcHdr.Len
+ dstHdr.Cap = srcHdr.Cap
+ n, err := primitive.CopyUint64SliceOut(s, StackBottomMagic, dst)
+ // Ensures src doesn't get GCed until we're done using it through dst.
+ runtime.KeepAlive(src)
+ return n + nNull, err
+ case 4:
+ nNull, err := primitive.CopyUint32Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ var dst []uint32
+ dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
+ dstHdr.Data = srcHdr.Data
+ dstHdr.Len = srcHdr.Len
+ dstHdr.Cap = srcHdr.Cap
+ n, err := primitive.CopyUint32SliceOut(s, StackBottomMagic, dst)
+ // Ensure src doesn't get GCed until we're done using it through dst.
+ runtime.KeepAlive(src)
+ return n + nNull, err
+ default:
+ panic("Unsupported arch width")
+ }
+}
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 668f47802..1d88db12f 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -183,9 +183,9 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
if initArgs.MountNamespaceVFS2 == nil {
// Set initArgs so that 'ctx' returns the namespace.
//
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
+ // Add a reference to the namespace, which is transferred to the new process.
initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ initArgs.MountNamespaceVFS2.IncRef()
}
} else {
if initArgs.MountNamespace == nil {
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 71c59287c..14a8bf9cd 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 0b701a289..655ea549b 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -16,6 +16,8 @@
package tundev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -26,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -84,7 +87,16 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := fd.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := fd.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9379a4d7b..6b7b451b8 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netstack",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 5f8c9b5a2..19ffdec47 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -15,6 +15,8 @@
package dev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -25,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -60,7 +63,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod
}
// GetFile implements fs.InodeOperations.GetFile.
-func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
}
@@ -80,12 +83,12 @@ type netTunFileOperations struct {
var _ fs.FileOperations = (*netTunFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (fops *netTunFileOperations) Release(ctx context.Context) {
- fops.device.Release(ctx)
+func (n *netTunFileOperations) Release(ctx context.Context) {
+ n.device.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
request := args[1].Uint()
data := args[2].Pointer()
@@ -109,16 +112,25 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := n.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := n.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
- copy(req.IFName[:], fops.device.Name())
+ copy(req.IFName[:], n.device.Name())
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
// there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
- flags := fops.device.Flags() | linux.IFF_NOFILTER
+ flags := n.device.Flags() | linux.IFF_NOFILTER
usermem.ByteOrder.PutUint16(req.Data[:], flags)
_, err := req.CopyOut(t, data)
@@ -130,41 +142,41 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
}
// Write implements fs.FileOperations.Write.
-func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
data := make([]byte, src.NumBytes())
if _, err := src.CopyIn(ctx, data); err != nil {
return 0, err
}
- return fops.device.Write(data)
+ return n.device.Write(data)
}
// Read implements fs.FileOperations.Read.
-func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- data, err := fops.device.Read()
+func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := n.device.Read()
if err != nil {
return 0, err
}
- n, err := dst.CopyOut(ctx, data)
- if n > 0 && n < len(data) {
+ bytesCopied, err := dst.CopyOut(ctx, data)
+ if bytesCopied > 0 && bytesCopied < len(data) {
// Not an error for partial copying. Packet truncated.
err = nil
}
- return int64(n), err
+ return int64(bytesCopied), err
}
// Readiness implements watier.Waitable.Readiness.
-func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return fops.device.Readiness(mask)
+func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return n.device.Readiness(mask)
}
// EventRegister implements watier.Waitable.EventRegister.
-func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- fops.device.EventRegister(e, mask)
+func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ n.device.EventRegister(e, mask)
}
// EventUnregister implements watier.Waitable.EventUnregister.
-func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
- fops.device.EventUnregister(e)
+func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ n.device.EventUnregister(e)
}
// isNetTunSupported returns whether /dev/net/tun device is supported for s.
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 9197aeb88..1dc409d38 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -84,7 +84,8 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRan
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
// mmap(2): bytes after EOF on the same page are zeroed; pages after EOF are
-// invalid.
+// invalid. fileSize is an upper bound on the file's size; bytes after fileSize
+// will be zeroed without calling readAt.
//
// Fill may read offsets outside of required, but will never read offsets
// outside of optional. It returns a non-nil error if any error occurs, even
@@ -94,7 +95,7 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRan
// * required.Length() > 0.
// * optional.IsSupersetOf(required).
// * required and optional must be page-aligned.
-func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, fileSize uint64, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
gap := frs.LowerBoundGap(required.Start)
for gap.Ok() && gap.Start() < required.End {
if gap.Range().Length() == 0 {
@@ -107,7 +108,21 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
fr, err := mf.AllocateAndFill(gr.Length(), kind, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
var done uint64
for !dsts.IsEmpty() {
- n, err := readAt(ctx, dsts, gr.Start+done)
+ n, err := func() (uint64, error) {
+ off := gr.Start + done
+ if off >= fileSize {
+ return 0, io.EOF
+ }
+ if off+dsts.NumBytes() > fileSize {
+ rd := fileSize - off
+ n, err := readAt(ctx, dsts.TakeFirst64(rd), off)
+ if n == rd && err == nil {
+ return n, io.EOF
+ }
+ return n, err
+ }
+ return readAt(ctx, dsts, off)
+ }()
done += n
dsts = dsts.DropFirst64(n)
if err != nil {
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 9eb6f522e..82eda3e43 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/kernel/time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -444,7 +443,7 @@ func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs.
// time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchAccessTimeLocked(now ktime.Time) {
c.attr.AccessTime = now
c.dirtyAttr.AccessTime = true
}
@@ -461,7 +460,7 @@ func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx contex
// and status change times in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now ktime.Time) {
c.attr.ModificationTime = now
c.dirtyAttr.ModificationTime = true
c.attr.StatusChangeTime = now
@@ -480,7 +479,7 @@ func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) {
// in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now ktime.Time) {
c.attr.StatusChangeTime = now
c.dirtyAttr.StatusChangeTime = true
}
@@ -645,7 +644,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
End: fs.OffsetPageEnd(int64(gapMR.End)),
}
optMR := gap.Range()
- err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
+ err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), uint64(rw.c.attr.Size), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
mem.MarkEvictable(rw.c, pgalloc.EvictableRange{optMR.Start, optMR.End})
seg, gap = rw.c.cache.Find(uint64(rw.offset))
if !seg.Ok() {
@@ -672,9 +671,6 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
-
- default:
- break
}
}
unlock()
@@ -768,9 +764,6 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
-
- default:
- break
}
}
rw.maybeGrowFile()
@@ -877,7 +870,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option
}
mf := c.mfp.MemoryFile()
- cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
+ cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), uint64(c.attr.Size), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
var ts []memmap.Translation
var translatedEnd uint64
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 103bfc600..22d658acf 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -84,6 +84,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo
"auxv": newAuxvec(t, msrc),
"cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
"comm": newComm(t, msrc),
+ "cwd": newCwd(t, msrc),
"environ": newExecArgInode(t, msrc, environExecArg),
"exe": newExe(t, msrc),
"fd": newFdDir(t, msrc),
@@ -300,6 +301,49 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
return exec.PathnameWithDeleted(ctx), nil
}
+// cwd is an fs.InodeOperations symlink for the /proc/PID/cwd file.
+//
+// +stateify savable
+type cwd struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ cwdSymlink := &cwd{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ t: t,
+ }
+ return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t)
+}
+
+// Readlink implements fs.InodeOperations.
+func (e *cwd) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !kernel.ContextCanTrace(ctx, e.t, false) {
+ return "", syserror.EACCES
+ }
+ if err := checkTaskState(e.t); err != nil {
+ return "", err
+ }
+ cwd := e.t.FSContext().WorkingDirectory()
+ if cwd == nil {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer cwd.DecRef(ctx)
+
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ name, _ := cwd.FullName(root)
+ return name, nil
+}
+
// namespaceSymlink represents a symlink in the namespacefs, such as the files
// in /proc/<pid>/ns.
//
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index 1dc75291d..fc0498f17 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -613,7 +613,7 @@ func (f *fileInodeOperations) Translate(ctx context.Context, required, optional
}
mf := f.kernel.MemoryFile()
- cerr := f.data.Fill(ctx, required, optional, mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ cerr := f.data.Fill(ctx, required, optional, uint64(f.attr.Size), mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
// Newly-allocated pages are zeroed, so we don't need to do anything.
return dsts.NumBytes(), nil
})
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index 2f5a43b84..124bc95ed 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -121,6 +121,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
root := mns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
for _, p := range paths {
if !path.IsAbs(p) {
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
index 936fd3932..1f8684dc6 100644
--- a/pkg/sentry/fs/user/user.go
+++ b/pkg/sentry/fs/user/user.go
@@ -105,6 +105,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.
const defaultHome = "/"
root := mns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
creds := auth.CredentialsFromContext(ctx)
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 323506d33..be0900030 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// remainingTraversals is not configurable in VFS2, all callers are using the
// default anyways.
func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
- vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ vfsObj := l.root.Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 48e13613a..84baaac66 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index 903135fae..d5c5aaa8c 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -37,27 +37,51 @@ const Name = "devpts"
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
-type FilesystemType struct{}
+type FilesystemType struct {
+ initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ initErr error
+
+ // fs backs all mounts of this FilesystemType. root is fs' root. fs and root
+ // are immutable.
+ fs *vfs.Filesystem
+ root *vfs.Dentry
+}
// Name implements vfs.FilesystemType.Name.
-func (FilesystemType) Name() string {
+func (*FilesystemType) Name() string {
return Name
}
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// No data allowed.
if opts.Data != "" {
return nil, nil, syserror.EINVAL
}
- fs, root, err := fstype.newFilesystem(vfsObj, creds)
- if err != nil {
- return nil, nil, err
+ fstype.initOnce.Do(func() {
+ fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ if err != nil {
+ fstype.initErr = err
+ return
+ }
+ fstype.fs = fs.VFSFilesystem()
+ fstype.root = root.VFSDentry()
+ })
+ if fstype.initErr != nil {
+ return nil, nil, fstype.initErr
+ }
+ fstype.fs.IncRef()
+ fstype.root.IncRef()
+ return fstype.fs, fstype.root, nil
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (fstype *FilesystemType) Release(ctx context.Context) {
+ if fstype.fs != nil {
+ fstype.root.DecRef(ctx)
+ fstype.fs.DecRef(ctx)
}
- return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil
}
// +stateify savable
@@ -69,7 +93,7 @@ type filesystem struct {
// newFilesystem creates a new devpts filesystem with root directory and ptmx
// master inode. It returns the filesystem and root Dentry.
-func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -87,7 +111,9 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
root.EnableLeakCheck()
- root.dentry.Init(root)
+
+ var rootD kernfs.Dentry
+ rootD.Init(&fs.Filesystem, root)
// Construct the pts master inode and dentry. Linux always uses inode
// id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
@@ -95,15 +121,14 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root: root,
}
master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
- master.dentry.Init(master)
// Add the master as a child of the root.
- links := root.OrderedChildren.Populate(&root.dentry, map[string]*kernfs.Dentry{
- "ptmx": &master.dentry,
+ links := root.OrderedChildren.Populate(map[string]kernfs.Inode{
+ "ptmx": master,
})
root.IncLinks(links)
- return fs, &root.dentry, nil
+ return fs, &rootD, nil
}
// Release implements vfs.FilesystemImpl.Release.
@@ -117,24 +142,19 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type rootInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
kernfs.OrderedChildren
rootInodeRefs
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// master is the master pty inode. Immutable.
master *masterInode
- // root is the root directory inode for this filesystem. Immutable.
- root *rootInode
-
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
@@ -173,21 +193,24 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
- replica.dentry.Init(replica)
i.replicas[idx] = replica
return t, nil
}
// masterClose is called when the master end of t is closed.
-func (i *rootInode) masterClose(t *Terminal) {
+func (i *rootInode) masterClose(ctx context.Context, t *Terminal) {
i.mu.Lock()
defer i.mu.Unlock()
// Sanity check that replica with idx exists.
- if _, ok := i.replicas[t.n]; !ok {
+ ri, ok := i.replicas[t.n]
+ if !ok {
panic(fmt.Sprintf("pty with index %d does not exist", t.n))
}
+
+ // Drop the ref on replica inode taken during rootInode.allocateTerminal.
+ ri.DecRef(ctx)
delete(i.replicas, t.n)
}
@@ -203,16 +226,22 @@ func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.D
}
// Lookup implements kernfs.Inode.Lookup.
-func (i *rootInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
+ // Check if a static entry was looked up.
+ if d, err := i.OrderedChildren.Lookup(ctx, name); err == nil {
+ return d, nil
+ }
+
+ // Not a static entry.
idx, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
}
i.mu.Lock()
defer i.mu.Unlock()
- if si, ok := i.replicas[uint32(idx)]; ok {
- si.dentry.IncRef()
- return &si.dentry, nil
+ if ri, ok := i.replicas[uint32(idx)]; ok {
+ ri.IncRef() // This ref is passed to the dentry upon creation via Init.
+ return ri, nil
}
return nil, syserror.ENOENT
@@ -243,8 +272,8 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback,
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *rootInode) DecRef(context.Context) {
- i.rootInodeRefs.DecRef(i.Destroy)
+func (i *rootInode) DecRef(ctx context.Context) {
+ i.rootInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// +stateify savable
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 69c2fe951..fda30fb93 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -42,9 +42,6 @@ type masterInode struct {
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// root is the devpts root inode.
root *rootInode
}
@@ -103,7 +100,7 @@ var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
func (mfd *masterFileDescription) Release(ctx context.Context) {
- mfd.inode.root.masterClose(mfd.t)
+ mfd.inode.root.masterClose(ctx, mfd.t)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
index 6515c5536..70c68cf0a 100644
--- a/pkg/sentry/fsimpl/devpts/replica.go
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -41,9 +41,6 @@ type replicaInode struct {
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// root is the devpts root inode.
root *rootInode
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index 6d1753080..e6fe0fc0d 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -71,6 +71,15 @@ func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virtua
return fst.fs, fst.root, nil
}
+// Release implements vfs.FilesystemType.Release.
+func (fst *FilesystemType) Release(ctx context.Context) {
+ if fst.fs != nil {
+ // Release the original reference obtained when creating the filesystem.
+ fst.root.DecRef(ctx)
+ fst.fs.DecRef(ctx)
+ }
+}
+
// Accessor allows devices to create device special files in devtmpfs.
type Accessor struct {
vfsObj *vfs.VirtualFilesystem
@@ -86,10 +95,13 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth
if err != nil {
return nil, err
}
+ // Pass a reference on root to the Accessor.
+ root := mntns.Root()
+ root.IncRef()
return &Accessor{
vfsObj: vfsObj,
mntns: mntns,
- root: mntns.Root(),
+ root: root,
creds: creds,
}, nil
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index 3a38b8bb4..e058eda7a 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -53,6 +53,7 @@ func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.Virtu
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
root := mntns.Root()
+ root.IncRef()
devpop := vfs.PathOperation{
Root: root,
Start: root,
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index abc610ef3..7b1eec3da 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/fd",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
@@ -86,9 +88,9 @@ go_test(
library = ":ext",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index c349b886e..2ee7cc7ac 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -70,6 +70,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
}
root := mntns.Root()
+ root.IncRef()
tearDown := func() {
root.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index 8bb104ff0..1165234f9 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -18,7 +18,7 @@ import (
"io"
"math"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,19 +34,19 @@ type blockMapFile struct {
// directBlks are the direct blocks numbers. The physical blocks pointed by
// these holds file data. Contains file blocks 0 to 11.
- directBlks [numDirectBlks]uint32
+ directBlks [numDirectBlks]primitive.Uint32
// indirectBlk is the physical block which contains (blkSize/4) direct block
// numbers (as uint32 integers).
- indirectBlk uint32
+ indirectBlk primitive.Uint32
// doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
// block numbers (as uint32 integers).
- doubleIndirectBlk uint32
+ doubleIndirectBlk primitive.Uint32
// tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
// indirect block numbers (as uint32 integers).
- tripleIndirectBlk uint32
+ tripleIndirectBlk primitive.Uint32
// coverage at (i)th index indicates the amount of file data a node at
// height (i) covers. Height 0 is the direct block.
@@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
}
blkMap := file.regFile.inode.diskInode.Data()
- binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
- binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ for i := 0; i < numDirectBlks; i++ {
+ file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4])
+ }
+ file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4])
+ file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4])
+ file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4])
return file, nil
}
@@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
switch {
case offset < dirBlksEnd:
// Direct block.
- curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:])
+ curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:])
case offset < indirBlkEnd:
// Indirect block.
- curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:])
+ curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:])
case offset < doubIndirBlkEnd:
// Doubly indirect block.
- curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:])
+ curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:])
default:
// Triply indirect block.
- curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:])
+ curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:])
}
read += curR
@@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
read := 0
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
- var childPhyBlk uint32
+ var childPhyBlk primitive.Uint32
err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
- n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:])
read += n
if err != nil {
return read, err
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 6fa84e7aa..ed98b482e 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -20,7 +20,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
var fileData []byte
blkNums := newBlkNumGen()
- var data []byte
+ off := 0
+ data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes())
// Write the direct blocks.
for i := 0; i < numDirectBlks; i++ {
- curBlkNum := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(data[off:])
+ off += curBlkNum.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...)
}
// Write to indirect block.
- indirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
-
- // Write to indirect block.
- doublyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
-
- // Write to indirect block.
- triplyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+ indirectBlk := primitive.Uint32(blkNums.next())
+ indirectBlk.MarshalBytes(data[off:])
+ off += indirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...)
+
+ // Write to double indirect block.
+ doublyIndirectBlk := primitive.Uint32(blkNums.next())
+ doublyIndirectBlk.MarshalBytes(data[off:])
+ off += doublyIndirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...)
+
+ // Write to triple indirect block.
+ triplyIndirectBlk := primitive.Uint32(blkNums.next())
+ triplyIndirectBlk.MarshalBytes(data[off:])
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...)
args := inodeArgs{
fs: &filesystem{
@@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN
var fileData []byte
for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
- curBlkNum := blkNums.next()
- copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
- fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(disk[off : off+4])
+ fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...)
}
return fileData
}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 452450d82..0ad79b381 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -16,7 +16,6 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -100,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
} else {
curDirent.diskDirent = &disklayout.DirentOld{}
}
- binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+ curDirent.diskDirent.UnmarshalBytes(buf)
if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
// Inode number and name length fields being set to 0 is used to indicate
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 9bd9c76c0..d98a05dd8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -22,10 +22,11 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/marshal",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
index ad6f4fef8..0d56ae9da 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// BlockGroup represents a Linux ext block group descriptor. An ext file system
// is split into a series of block groups. This provides an access layer to
// information needed to access and use a block group.
@@ -30,6 +34,8 @@ package disklayout
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors.
type BlockGroup interface {
+ marshal.Marshallable
+
// InodeTable returns the absolute block number of the block containing the
// inode table. This points to an array of Inode structs. Inode tables are
// statically allocated at mkfs time. The superblock records the number of
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
index 3e16c76db..a35fa22a0 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
@@ -17,6 +17,8 @@ package disklayout
// BlockGroup32Bit emulates the first half of struct ext4_group_desc in
// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and
// 32-bit ext4 filesystems. It implements BlockGroup interface.
+//
+// +marshal
type BlockGroup32Bit struct {
BlockBitmapLo uint32
InodeBitmapLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
index 9a809197a..d54d1d345 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
@@ -18,6 +18,8 @@ package disklayout
// It is the block group descriptor struct for 64-bit ext4 filesystems.
// It implements BlockGroup interface. It is an extension of the 32-bit
// version of BlockGroup.
+//
+// +marshal
type BlockGroup64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
index 0ef4294c0..e4ce484e4 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
@@ -21,6 +21,8 @@ import (
// TestBlockGroupSize tests that the block group descriptor structs are of the
// correct size.
func TestBlockGroupSize(t *testing.T) {
- assertSize(t, BlockGroup32Bit{}, 32)
- assertSize(t, BlockGroup64Bit{}, 64)
+ var bgSmall BlockGroup32Bit
+ assertSize(t, &bgSmall, 32)
+ var bgBig BlockGroup64Bit
+ assertSize(t, &bgBig, 64)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
index 417b6cf65..568c8cb4c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -15,6 +15,7 @@
package disklayout
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -51,6 +52,8 @@ var (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories.
type Dirent interface {
+ marshal.Marshallable
+
// Inode returns the absolute inode number of the underlying inode.
// Inode number 0 signifies an unused dirent.
Inode() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
index 29ae4a5c2..51f9c2946 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
@@ -29,12 +29,14 @@ import (
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentNew struct {
InodeNumber uint32
RecordLength uint16
NameLength uint8
FileTypeRaw uint8
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentNew implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
index 6fff12a6e..d4b19e086 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
@@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs"
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentOld struct {
InodeNumber uint32
RecordLength uint16
NameLength uint16
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentOld implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
index 934919f8a..3486864dc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -21,6 +21,8 @@ import (
// TestDirentSize tests that the dirent structs are of the correct
// size.
func TestDirentSize(t *testing.T) {
- assertSize(t, DirentOld{}, uintptr(DirentSize))
- assertSize(t, DirentNew{}, uintptr(DirentSize))
+ var dOld DirentOld
+ assertSize(t, &dOld, DirentSize)
+ var dNew DirentNew
+ assertSize(t, &dNew, DirentSize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
index bdf4e2132..0834e9ba8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
@@ -36,8 +36,6 @@
// escape analysis on an unknown implementation at compile time.
//
// Notes:
-// - All fields in these structs are exported because binary.Read would
-// panic otherwise.
// - All structures on disk are in little-endian order. Only jbd2 (journal)
// structures are in big-endian order.
// - All OS dependent fields in these structures will be interpretted using
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 4110649ab..b13999bfc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// Extents were introduced in ext4 and provide huge performance gains in terms
// data locality and reduced metadata block usage. Extents are organized in
// extent trees. The root node is contained in inode.BlocksRaw.
@@ -64,6 +68,8 @@ type ExtentNode struct {
// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
+ marshal.Marshallable
+
// FileBlock returns the first file block number covered by this entry.
FileBlock() uint32
@@ -75,6 +81,8 @@ type ExtentEntry interface {
// tree node begins with this and is followed by `NumEntries` number of:
// - Extent if `Depth` == 0
// - ExtentIdx otherwise
+//
+// +marshal
type ExtentHeader struct {
// Magic in the extent magic number, must be 0xf30a.
Magic uint16
@@ -96,6 +104,8 @@ type ExtentHeader struct {
// internal nodes. Sorted in ascending order based on FirstFileBlock since
// Linux does a binary search on this. This points to a block containing the
// child node.
+//
+// +marshal
type ExtentIdx struct {
FirstFileBlock uint32
ChildBlockLo uint32
@@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 {
// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a
// binary search on this. This points to an array of data blocks containing the
// file data. It covers `Length` data blocks starting from `StartBlock`.
+//
+// +marshal
type Extent struct {
FirstFileBlock uint32
Length uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index 8762b90db..c96002e19 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,10 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentHeaderSize)
- assertSize(t, ExtentIdx{}, ExtentEntrySize)
- assertSize(t, Extent{}, ExtentEntrySize)
+ var h ExtentHeader
+ assertSize(t, &h, ExtentHeaderSize)
+ var i ExtentIdx
+ assertSize(t, &i, ExtentEntrySize)
+ var e Extent
+ assertSize(t, &e, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
index 88ae913f5..ef25040a9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
@@ -16,6 +16,7 @@ package disklayout
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
@@ -38,6 +39,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes.
type Inode interface {
+ marshal.Marshallable
+
// Mode returns the linux file mode which is majorly used to extract
// information like:
// - File permissions (read/write/execute by user/group/others).
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
index 8f9f574ce..a4503f5cf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
@@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time"
// are used to provide nanoscond precision. Hence, these timestamps will now
// overflow in May 2446.
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+//
+// +marshal
type InodeNew struct {
InodeOld
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
index db25b11b6..e6b28babf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
@@ -30,6 +30,8 @@ const (
//
// All fields representing time are in seconds since the epoch. Which means that
// they will overflow in January 2038.
+//
+// +marshal
type InodeOld struct {
ModeRaw uint16
UIDLo uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
index dd03ee50e..90744e956 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
@@ -24,10 +24,12 @@ import (
// TestInodeSize tests that the inode structs are of the correct size.
func TestInodeSize(t *testing.T) {
- assertSize(t, InodeOld{}, OldInodeSize)
+ var iOld InodeOld
+ assertSize(t, &iOld, OldInodeSize)
// This was updated from 156 bytes to 160 bytes in Oct 2015.
- assertSize(t, InodeNew{}, 160)
+ var iNew InodeNew
+ assertSize(t, &iNew, 160)
}
// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
index 8bb327006..70948ebe9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
const (
// SbOffset is the absolute offset at which the superblock is placed.
SbOffset = 1024
@@ -38,6 +42,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block.
type SuperBlock interface {
+ marshal.Marshallable
+
// InodesCount returns the total number of inodes in this filesystem.
InodesCount() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..4dc6080fb 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
@@ -17,6 +17,8 @@ package disklayout
// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
// RevLevel = DynamicRev and 64-bit feature is disabled.
+//
+// +marshal
type SuperBlock32Bit struct {
// We embed the old superblock struct here because the 32-bit version is just
// an extension of the old version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..2c9039327 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
@@ -19,6 +19,8 @@ package disklayout
// 1024 bytes (smallest possible block size) and hence the superblock always
// fits in no more than one data block. Should only be used when the 64-bit
// feature is set.
+//
+// +marshal
type SuperBlock64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..e4709f23c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
@@ -16,6 +16,8 @@ package disklayout
// SuperBlockOld implements SuperBlock and represents the old version of the
// superblock struct. Should be used only if RevLevel = OldRev.
+//
+// +marshal
type SuperBlockOld struct {
InodesCountRaw uint32
BlocksCountLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..b734b6987 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
@@ -21,7 +21,10 @@ import (
// TestSuperBlockSize tests that the superblock structs are of the correct
// size.
func TestSuperBlockSize(t *testing.T) {
- assertSize(t, SuperBlockOld{}, 84)
- assertSize(t, SuperBlock32Bit{}, 336)
- assertSize(t, SuperBlock64Bit{}, 1024)
+ var sbOld SuperBlockOld
+ assertSize(t, &sbOld, 84)
+ var sb32 SuperBlock32Bit
+ assertSize(t, &sb32, 336)
+ var sb64 SuperBlock64Bit
+ assertSize(t, &sb64, 1024)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..a4bc08411 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
@@ -18,13 +18,13 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
)
-func assertSize(t *testing.T, v interface{}, want uintptr) {
+func assertSize(t *testing.T, v marshal.Marshallable, want int) {
t.Helper()
- if got := binary.Size(v); got != want {
+ if got := v.SizeBytes(); got != want {
t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got)
}
}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index aca258d40..38fb7962b 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -38,9 +38,6 @@ const Name = "ext"
// +stateify savable
type FilesystemType struct{}
-// Compiles only if FilesystemType implements vfs.FilesystemType.
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// getDeviceFd returns an io.ReaderAt to the underlying device.
// Currently there are two ways of mounting an ext(2/3/4) fs:
// 1. Specify a mount with our internal special MountType in the OCI spec.
@@ -101,6 +98,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 0989558cd..d9fd4590c 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -82,6 +82,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
}
root := mntns.Root()
+ root.IncRef()
tearDown := func() {
root.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 04917d762..778460107 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -18,7 +18,6 @@ import (
"io"
"sort"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -60,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
+ f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize])
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -79,7 +78,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
+ curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize])
f.root.Entries[i].Entry = curEntry
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index cd10d46ee..985f76ac0 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -21,7 +21,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
// writeTree writes the tree represented by `root` to the inode and disk. It
// also writes random file data on disk.
func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
- rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
+ rootData := in.diskInode.Data()
+ root.Header.MarshalBytes(rootData)
+ off := root.Header.SizeBytes()
for _, ep := range root.Entries {
- rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(rootData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(in.diskInode.Data(), rootData)
-
var fileData []byte
for _, ep := range root.Entries {
if root.Header.Height == 0 {
@@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl
// writeTreeToDisk is the recursive step for writeTree which writes the tree
// on the disk only. Also writes random file data on disk.
func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
- nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
+ nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:]
+ curNode.Node.Header.MarshalBytes(nodeData)
+ off := curNode.Node.Header.SizeBytes()
for _, ep := range curNode.Node.Entries {
- nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(nodeData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
-
var fileData []byte
for _, ep := range curNode.Node.Entries {
if curNode.Node.Header.Height == 0 {
diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index d8b728f8c..58ef7b9b8 100644
--- a/pkg/sentry/fsimpl/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -17,21 +17,21 @@ package ext
import (
"io"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// readFromDisk performs a binary read from disk into the given struct from
// the absolute offset provided.
-func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
- n := binary.Size(v)
+func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error {
+ n := v.SizeBytes()
buf := make([]byte, n)
if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
return syserror.EIO
}
- binary.Unmarshal(buf, binary.LittleEndian, v)
+ v.UnmarshalBytes(buf)
return nil
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 65786e42a..e39df21c6 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -98,6 +98,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
@@ -249,14 +252,12 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type inode struct {
inodeRefs
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.OrderedChildren
- dentry kernfs.Dentry
-
// the owning filesystem. fs is immutable.
fs *filesystem
@@ -284,26 +285,24 @@ type inode struct {
}
func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
- i := &inode{fs: fs}
+ i := &inode{fs: fs, nodeID: 1}
i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
- i.dentry.Init(i)
- i.nodeID = 1
- return &i.dentry
+ var d kernfs.Dentry
+ d.Init(&fs.Filesystem, i)
+ return &d
}
-func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry {
+func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
i := &inode{fs: fs, nodeID: nodeID}
creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
- i.dentry.Init(i)
-
- return &i.dentry
+ return i
}
// Open implements kernfs.Inode.Open.
@@ -410,23 +409,27 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
}
// Lookup implements kernfs.Inode.Lookup.
-func (i *inode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
in := linux.FUSELookupIn{Name: name}
return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in)
}
+// Keep implements kernfs.Inode.Keep.
+func (i *inode) Keep() bool {
+ // Return true so that kernfs keeps the new dentry pointing to this
+ // inode in the dentry tree. This is needed because inodes created via
+ // Lookup are not temporary. They might refer to existing files on server
+ // that can be Unlink'd/Rmdir'd.
+ return true
+}
+
// IterDirents implements kernfs.Inode.IterDirents.
func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
-// Valid implements kernfs.Inode.Valid.
-func (*inode) Valid(ctx context.Context) bool {
- return true
-}
-
// NewFile implements kernfs.Inode.NewFile.
-func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID)
@@ -444,7 +447,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions)
}
// NewNode implements kernfs.Inode.NewNode.
-func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (kernfs.Inode, error) {
in := linux.FUSEMknodIn{
MknodMeta: linux.FUSEMknodMeta{
Mode: uint32(opts.Mode),
@@ -457,7 +460,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions)
}
// NewSymlink implements kernfs.Inode.NewSymlink.
-func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.Dentry, error) {
+func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) {
in := linux.FUSESymLinkIn{
Name: name,
Target: target,
@@ -466,7 +469,7 @@ func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.De
}
// Unlink implements kernfs.Inode.Unlink.
-func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) error {
+func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) error {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
@@ -482,14 +485,11 @@ func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) e
return err
}
// only return error, discard res.
- if err := res.Error(); err != nil {
- return err
- }
- return i.dentry.RemoveChildLocked(name, child)
+ return res.Error()
}
// NewDir implements kernfs.Inode.NewDir.
-func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
in := linux.FUSEMkdirIn{
MkdirMeta: linux.FUSEMkdirMeta{
Mode: uint32(opts.Mode),
@@ -501,7 +501,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions)
}
// RmDir implements kernfs.Inode.RmDir.
-func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) error {
+func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) error {
fusefs := i.fs
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
@@ -515,16 +515,12 @@ func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) er
if err != nil {
return err
}
- if err := res.Error(); err != nil {
- return err
- }
-
- return i.dentry.RemoveChildLocked(name, child)
+ return res.Error()
}
// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response.
// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP.
-func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*kernfs.Dentry, error) {
+func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (kernfs.Inode, error) {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
@@ -734,8 +730,8 @@ func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptio
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *inode) DecRef(context.Context) {
- i.inodeRefs.DecRef(i.Destroy)
+func (i *inode) DecRef(ctx context.Context) {
+ i.inodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// StatFS implements kernfs.Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 16787116f..ad0afc41b 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -52,6 +52,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/p9",
+ "//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 8608471f8..f1dad1b08 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -272,6 +272,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index eeaf6e444..f8b19bae7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -395,7 +395,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
End: gapEnd,
}
optMR := gap.Range()
- err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, h.readToBlocksAt)
+ err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), rw.d.size, mf, usage.PageCache, h.readToBlocksAt)
mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End})
seg, gap = rw.d.cache.Find(rw.off)
if !seg.Ok() {
@@ -403,10 +403,10 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
rw.d.handleMu.RUnlock()
return done, err
}
- // err might have occurred in part of gap.Range() outside
- // gapMR. Forget about it for now; if the error matters and
- // persists, we'll run into it again in a later iteration of
- // this loop.
+ // err might have occurred in part of gap.Range() outside gapMR
+ // (in particular, gap.End() might be beyond EOF). Forget about
+ // it for now; if the error matters and persists, we'll run
+ // into it again in a later iteration of this loop.
} else {
// Read directly from the file.
gapDsts := dsts.TakeFirst64(gapMR.Length())
@@ -780,7 +780,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
mf := d.fs.mfp.MemoryFile()
h := d.readHandleLocked()
- cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, h.readToBlocksAt)
+ cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), d.size, mf, usage.PageCache, h.readToBlocksAt)
var ts []memmap.Translation
var translatedEnd uint64
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index ffe4ddb32..698e913fe 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -118,7 +118,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
if err != nil {
return nil, err
}
- d.Init(i)
+ d.Init(&fs.Filesystem, i)
// i.open will take a reference on d.
defer d.DecRef(ctx)
@@ -151,6 +151,9 @@ func (filesystemType) Name() string {
return "none"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// NewFilesystem sets up and returns a new hostfs filesystem.
//
// Note that there should only ever be one instance of host.filesystem,
@@ -195,6 +198,7 @@ type inode struct {
kernfs.InodeNoStatFS
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
locks vfs.FileLocks
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 131145b85..8a447e29f 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -348,10 +348,10 @@ func (e *SCMConnectedEndpoint) Init() error {
func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
e.DecRef(func() {
e.mu.Lock()
+ fdnotifier.RemoveFD(int32(e.fd))
if err := syscall.Close(e.fd); err != nil {
log.Warningf("Failed to close host fd %d: %v", err)
}
- fdnotifier.RemoveFD(int32(e.fd))
e.destroyLocked()
e.mu.Unlock()
})
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 5e91e0536..858cc24ce 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -70,6 +70,17 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "synthetic_directory_refs",
+ out = "synthetic_directory_refs.go",
+ package = "kernfs",
+ prefix = "syntheticDirectory",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "syntheticDirectory",
+ },
+)
+
go_library(
name = "kernfs",
srcs = [
@@ -84,6 +95,7 @@ go_library(
"static_directory_refs.go",
"symlink.go",
"synthetic_directory.go",
+ "synthetic_directory_refs.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 0a4cd4057..abf1905d6 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -201,12 +201,12 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// these.
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
- stat, err := it.Dentry.inode.Stat(ctx, fd.filesystem(), opts)
+ stat, err := it.inode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
dirent := vfs.Dirent{
- Name: it.Name,
+ Name: it.name,
Type: linux.FileMode(stat.Mode).DirentType(),
Ino: stat.Ino,
NextOff: fd.off + 1,
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 5cc1c4281..6426a55f6 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -89,7 +89,7 @@ afterSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef(ctx)
+ fs.deferDecRefVD(ctx, targetVD)
if err != nil {
return nil, err
}
@@ -120,22 +120,33 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// Cached dentry exists, revalidate.
if !child.inode.Valid(ctx) {
delete(parent.children, name)
- vfsObj.InvalidateDentry(ctx, &child.vfsd)
- fs.deferDecRef(child) // Reference from Lookup.
+ if child.inode.Keep() {
+ // Drop the ref owned by kernfs.
+ fs.deferDecRef(child)
+ }
+ vfsObj.InvalidateDentry(ctx, child.VFSDentry())
child = nil
}
}
if child == nil {
// Dentry isn't cached; it either doesn't exist or failed revalidation.
// Attempt to resolve it via Lookup.
- c, err := parent.inode.Lookup(ctx, name)
+ childInode, err := parent.inode.Lookup(ctx, name)
if err != nil {
return nil, err
}
- // Reference on c (provided by Lookup) will be dropped when the dentry
- // fails validation.
- parent.InsertChildLocked(name, c)
- child = c
+ var newChild Dentry
+ newChild.Init(fs, childInode) // childInode's ref is transferred to newChild.
+ parent.insertChildLocked(name, &newChild)
+ child = &newChild
+
+ // Drop the ref on newChild. This will cause the dentry to get pruned
+ // from the dentry tree by the end of current filesystem operation
+ // (before returning to the VFS layer) if another ref is not picked on
+ // this dentry.
+ if !childInode.Keep() {
+ fs.deferDecRef(&newChild)
+ }
}
return child, nil
}
@@ -191,7 +202,7 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
}
// checkCreateLocked checks that a file named rp.Component() may be created in
-// directory parentVFSD, then returns rp.Component().
+// directory parent, then returns rp.Component().
//
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
@@ -298,9 +309,9 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -324,11 +335,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EPERM
}
- child, err := parent.inode.NewLink(ctx, pc, d.inode)
+ childI, err := parent.inode.NewLink(ctx, pc, d.inode)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
@@ -338,9 +351,9 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -355,14 +368,16 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- child, err := parent.inode.NewDir(ctx, pc, opts)
+ childI, err := parent.inode.NewDir(ctx, pc, opts)
if err != nil {
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
- child = newSyntheticDirectory(rp.Credentials(), opts.Mode)
+ childI = newSyntheticDirectory(rp.Credentials(), opts.Mode)
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
@@ -372,9 +387,9 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -389,11 +404,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- newD, err := parent.inode.NewNode(ctx, pc, opts)
+ newI, err := parent.inode.NewNode(ctx, pc, opts)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, newD)
+ var newD Dentry
+ newD.Init(fs, newI)
+ parent.insertChildLocked(pc, &newD)
return nil
}
@@ -409,22 +426,23 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
return nil, err
}
if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
return nil, err
}
- d.inode.IncRef()
- defer d.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef d to prevent
+ // its destruction while fs.mu is unlocked.
+ d.IncRef()
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
- return d.inode.Open(ctx, rp, d, opts)
+ fd, err := d.inode.Open(ctx, rp, d, opts)
+ d.DecRef(ctx)
+ return fd, err
}
// May create new file.
@@ -438,6 +456,10 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
unlocked = true
}
}
+ // Process all to-be-decref'd dentries at the end at once.
+ // Since we defer unlock() AFTER this, fs.mu is guaranteed to be unlocked
+ // when this is executed.
+ defer fs.processDeferredDecRefs(ctx)
defer unlock()
if rp.Done() {
if rp.MustBeDir() {
@@ -449,14 +471,16 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- d.inode.IncRef()
- defer d.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef d to prevent
+ // its destruction while fs.mu is unlocked.
+ d.IncRef()
unlock()
- return d.inode.Open(ctx, rp, d, opts)
+ fd, err := d.inode.Open(ctx, rp, d, opts)
+ d.DecRef(ctx)
+ return fd, err
}
afterTrailingSymlink:
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return nil, err
}
@@ -487,18 +511,23 @@ afterTrailingSymlink:
}
defer rp.Mount().EndWrite()
// Create and open the child.
- child, err := parent.inode.NewFile(ctx, pc, opts)
+ childI, err := parent.inode.NewFile(ctx, pc, opts)
if err != nil {
return nil, err
}
+ var child Dentry
+ child.Init(fs, childI)
// FIXME(gvisor.dev/issue/1193): Race between checking existence with
- // fs.stepExistingLocked and parent.InsertChild. If possible, we should hold
+ // fs.stepExistingLocked and parent.insertChild. If possible, we should hold
// dirMu from one to the other.
- parent.InsertChild(pc, child)
- child.inode.IncRef()
- defer child.inode.DecRef(ctx)
+ parent.insertChild(pc, &child)
+ // Open may block so we need to unlock fs.mu. IncRef child to prevent
+ // its destruction while fs.mu is unlocked.
+ child.IncRef()
unlock()
- return child.inode.Open(ctx, rp, child, opts)
+ fd, err := child.inode.Open(ctx, rp, &child, opts)
+ child.DecRef(ctx)
+ return fd, err
}
if err != nil {
return nil, err
@@ -514,7 +543,7 @@ afterTrailingSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef(ctx)
+ fs.deferDecRefVD(ctx, targetVD)
if err != nil {
return nil, err
}
@@ -530,18 +559,21 @@ afterTrailingSymlink:
if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- child.inode.IncRef()
- defer child.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef child to prevent
+ // its destruction while fs.mu is unlocked.
+ child.IncRef()
unlock()
- return child.inode.Open(ctx, rp, child, opts)
+ fd, err := child.inode.Open(ctx, rp, child, opts)
+ child.DecRef(ctx)
+ return fd, err
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -560,7 +592,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
fs.mu.Lock()
- defer fs.processDeferredDecRefsLocked(ctx)
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
// Resolve the destination directory first to verify that it's on this
@@ -632,24 +664,27 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
return err
}
- replaced, err := srcDir.inode.Rename(ctx, src.name, pc, src, dstDir)
+ err = srcDir.inode.Rename(ctx, src.name, pc, src.inode, dstDir.inode)
if err != nil {
virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
return err
}
delete(srcDir.children, src.name)
if srcDir != dstDir {
- fs.deferDecRef(srcDir)
- dstDir.IncRef()
+ fs.deferDecRef(srcDir) // child (src) drops ref on old parent.
+ dstDir.IncRef() // child (src) takes a ref on the new parent.
}
src.parent = dstDir
src.name = pc
if dstDir.children == nil {
dstDir.children = make(map[string]*Dentry)
}
+ replaced := dstDir.children[pc]
dstDir.children[pc] = src
var replaceVFSD *vfs.Dentry
if replaced != nil {
+ // deferDecRef so that fs.mu and dstDir.mu are unlocked by then.
+ fs.deferDecRef(replaced)
replaceVFSD = replaced.VFSDentry()
}
virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD)
@@ -659,10 +694,10 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -691,10 +726,13 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
- if err := parentDentry.inode.RmDir(ctx, d.name, d); err != nil {
+ if err := parentDentry.inode.RmDir(ctx, d.name, d.inode); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
+ delete(parentDentry.children, d.name)
+ // Defer decref so that fs.mu and parentDentry.dirMu are unlocked by then.
+ fs.deferDecRef(d)
virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -702,9 +740,9 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -717,9 +755,9 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
// StatAt implements vfs.FilesystemImpl.StatAt.
func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statx{}, err
}
@@ -729,9 +767,9 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statfs{}, err
}
@@ -744,9 +782,9 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -761,21 +799,23 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return err
}
defer rp.Mount().EndWrite()
- child, err := parent.inode.NewSymlink(ctx, pc, target)
+ childI, err := parent.inode.NewSymlink(ctx, pc, target)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -799,10 +839,13 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
- if err := parentDentry.inode.Unlink(ctx, d.name, d); err != nil {
+ if err := parentDentry.inode.Unlink(ctx, d.name, d.inode); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
+ delete(parentDentry.children, d.name)
+ // Defer decref so that fs.mu and parentDentry.dirMu are unlocked by then.
+ fs.deferDecRef(d)
virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -810,9 +853,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -825,9 +868,9 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -838,9 +881,9 @@ func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -851,9 +894,9 @@ func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -864,9 +907,9 @@ func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -880,3 +923,16 @@ func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.mu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*Dentry), b)
}
+
+func (fs *Filesystem) deferDecRefVD(ctx context.Context, vd vfs.VirtualDentry) {
+ if d, ok := vd.Dentry().Impl().(*Dentry); ok && d.fs == fs {
+ // The following is equivalent to vd.DecRef(ctx). This is needed
+ // because if d belongs to this filesystem, we can not DecRef it right
+ // away as we may be holding fs.mu. d.DecRef may acquire fs.mu. So we
+ // defer the DecRef to when locks are dropped.
+ vd.Mount().DecRef(ctx)
+ fs.deferDecRef(d)
+ } else {
+ vd.DecRef(ctx)
+ }
+}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 49210e748..122b10591 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -34,6 +34,7 @@ import (
//
// +stateify savable
type InodeNoopRefCount struct {
+ InodeTemporary
}
// IncRef implements Inode.IncRef.
@@ -57,27 +58,27 @@ func (InodeNoopRefCount) TryIncRef() bool {
type InodeDirectoryNoNewChildren struct{}
// NewFile implements Inode.NewFile.
-func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewLink implements Inode.NewLink.
-func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (Inode, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (Inode, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (Inode, error) {
return nil, syserror.EPERM
}
@@ -88,6 +89,7 @@ func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOpt
//
// +stateify savable
type InodeNotDirectory struct {
+ InodeAlwaysValid
}
// HasChildren implements Inode.HasChildren.
@@ -96,47 +98,47 @@ func (InodeNotDirectory) HasChildren() bool {
}
// NewFile implements Inode.NewFile.
-func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (Inode, error) {
panic("NewFile called on non-directory inode")
}
// NewDir implements Inode.NewDir.
-func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (Inode, error) {
panic("NewDir called on non-directory inode")
}
// NewLink implements Inode.NewLinkink.
-func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*Dentry, error) {
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (Inode, error) {
panic("NewLink called on non-directory inode")
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*Dentry, error) {
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (Inode, error) {
panic("NewSymlink called on non-directory inode")
}
// NewNode implements Inode.NewNode.
-func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (Inode, error) {
panic("NewNode called on non-directory inode")
}
// Unlink implements Inode.Unlink.
-func (InodeNotDirectory) Unlink(context.Context, string, *Dentry) error {
+func (InodeNotDirectory) Unlink(context.Context, string, Inode) error {
panic("Unlink called on non-directory inode")
}
// RmDir implements Inode.RmDir.
-func (InodeNotDirectory) RmDir(context.Context, string, *Dentry) error {
+func (InodeNotDirectory) RmDir(context.Context, string, Inode) error {
panic("RmDir called on non-directory inode")
}
// Rename implements Inode.Rename.
-func (InodeNotDirectory) Rename(context.Context, string, string, *Dentry, *Dentry) (*Dentry, error) {
+func (InodeNotDirectory) Rename(context.Context, string, string, Inode, Inode) error {
panic("Rename called on non-directory inode")
}
// Lookup implements Inode.Lookup.
-func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*Dentry, error) {
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) {
panic("Lookup called on non-directory inode")
}
@@ -145,35 +147,6 @@ func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDiren
panic("IterDirents called on non-directory inode")
}
-// Valid implements Inode.Valid.
-func (InodeNotDirectory) Valid(context.Context) bool {
- return true
-}
-
-// InodeNoDynamicLookup partially implements the Inode interface, specifically
-// the inodeDynamicLookup sub interface. Directory inodes that do not support
-// dymanic entries (i.e. entries that are not "hashed" into the
-// vfs.Dentry.children) can embed this to provide no-op implementations for
-// functions related to dynamic entries.
-//
-// +stateify savable
-type InodeNoDynamicLookup struct{}
-
-// Lookup implements Inode.Lookup.
-func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*Dentry, error) {
- return nil, syserror.ENOENT
-}
-
-// IterDirents implements Inode.IterDirents.
-func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- return offset, nil
-}
-
-// Valid implements Inode.Valid.
-func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
- return true
-}
-
// InodeNotSymlink partially implements the Inode interface, specifically the
// inodeSymlink sub interface. All inodes that are not symlinks may embed this
// to return the appropriate errors from symlink-related functions.
@@ -273,7 +246,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
// This function can be used by other kernfs-based filesystem implementation to
-// sets the unexported attributes into kernfs.InodeAttrs.
+// sets the unexported attributes into InodeAttrs.
func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
@@ -344,8 +317,9 @@ func (a *InodeAttrs) DecLinks() {
// +stateify savable
type slot struct {
- Name string
- Dentry *Dentry
+ name string
+ inode Inode
+ static bool
slotEntry
}
@@ -361,10 +335,18 @@ type OrderedChildrenOptions struct {
}
// OrderedChildren partially implements the Inode interface. OrderedChildren can
-// be embedded in directory inodes to keep track of the children in the
+// be embedded in directory inodes to keep track of children in the
// directory, and can then be used to implement a generic directory FD -- see
-// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
-// directories.
+// GenericDirectoryFD.
+//
+// OrderedChildren can represent a node in an Inode tree. The children inodes
+// might be directories themselves using OrderedChildren; hence extending the
+// tree. The parent inode (OrderedChildren user) holds a ref on all its static
+// children. This lets the static inodes outlive their associated dentry.
+// While the dentry might have to be regenerated via a Lookup() call, we can
+// keep reusing the same static inode. These static children inodes are finally
+// DecRef'd when this directory inode is being destroyed. This makes
+// OrderedChildren suitable for static directory entries as well.
//
// Must be initialize with Init before first use.
//
@@ -388,33 +370,63 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
// Destroy clears the children stored in o. It should be called by structs
// embedding OrderedChildren upon destruction, i.e. when their reference count
// reaches zero.
-func (o *OrderedChildren) Destroy() {
+func (o *OrderedChildren) Destroy(ctx context.Context) {
o.mu.Lock()
defer o.mu.Unlock()
+ // Drop the ref that o owns on the static inodes it holds.
+ for _, s := range o.set {
+ if s.static {
+ s.inode.DecRef(ctx)
+ }
+ }
o.order.Reset()
o.set = nil
}
-// Populate inserts children into this OrderedChildren, and d's dentry
-// cache. Populate returns the number of directories inserted, which the caller
+// Populate inserts static children into this OrderedChildren.
+// Populate returns the number of directories inserted, which the caller
// may use to update the link count for the parent directory.
//
-// Precondition: d must represent a directory inode. children must not contain
-// any conflicting entries already in o.
-func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
+// Precondition:
+// * d must represent a directory inode.
+// * children must not contain any conflicting entries already in o.
+// * Caller must hold a reference on all inodes passed.
+//
+// Postcondition: Caller's references on inodes are transferred to o.
+func (o *OrderedChildren) Populate(children map[string]Inode) uint32 {
var links uint32
for name, child := range children {
- if child.isDir() {
+ if child.Mode().IsDir() {
links++
}
- if err := o.Insert(name, child); err != nil {
- panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
+ if err := o.insert(name, child, true); err != nil {
+ panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v)", name, child))
}
- d.InsertChild(name, child)
}
return links
}
+// Lookup implements Inode.Lookup.
+func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error) {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+
+ s, ok := o.set[name]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+
+ s.inode.IncRef() // This ref is passed to the dentry upon creation via Init.
+ return s.inode, nil
+}
+
+// IterDirents implements Inode.IterDirents.
+func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ // All entries from OrderedChildren have already been handled in
+ // GenericDirectoryFD.IterDirents.
+ return offset, nil
+}
+
// HasChildren implements Inode.HasChildren.
func (o *OrderedChildren) HasChildren() bool {
o.mu.RLock()
@@ -422,17 +434,27 @@ func (o *OrderedChildren) HasChildren() bool {
return len(o.set) > 0
}
-// Insert inserts child into o. This ignores the writability of o, as this is
-// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
-func (o *OrderedChildren) Insert(name string, child *Dentry) error {
+// Insert inserts a dynamic child into o. This ignores the writability of o, as
+// this is not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
+func (o *OrderedChildren) Insert(name string, child Inode) error {
+ return o.insert(name, child, false)
+}
+
+// insert inserts child into o.
+//
+// Precondition: Caller must be holding a ref on child if static is true.
+//
+// Postcondition: Caller's ref on child is transferred to o if static is true.
+func (o *OrderedChildren) insert(name string, child Inode, static bool) error {
o.mu.Lock()
defer o.mu.Unlock()
if _, ok := o.set[name]; ok {
return syserror.EEXIST
}
s := &slot{
- Name: name,
- Dentry: child,
+ name: name,
+ inode: child,
+ static: static,
}
o.order.PushBack(s)
o.set[name] = s
@@ -442,44 +464,49 @@ func (o *OrderedChildren) Insert(name string, child *Dentry) error {
// Precondition: caller must hold o.mu for writing.
func (o *OrderedChildren) removeLocked(name string) {
if s, ok := o.set[name]; ok {
+ if s.static {
+ panic(fmt.Sprintf("removeLocked called on a static inode: %v", s.inode))
+ }
delete(o.set, name)
o.order.Remove(s)
}
}
// Precondition: caller must hold o.mu for writing.
-func (o *OrderedChildren) replaceChildLocked(name string, new *Dentry) *Dentry {
+func (o *OrderedChildren) replaceChildLocked(ctx context.Context, name string, newI Inode) {
if s, ok := o.set[name]; ok {
+ if s.static {
+ panic(fmt.Sprintf("replacing a static inode: %v", s.inode))
+ }
+
// Existing slot with given name, simply replace the dentry.
- var old *Dentry
- old, s.Dentry = s.Dentry, new
- return old
+ s.inode = newI
}
// No existing slot with given name, create and hash new slot.
s := &slot{
- Name: name,
- Dentry: new,
+ name: name,
+ inode: newI,
+ static: false,
}
o.order.PushBack(s)
o.set[name] = s
- return nil
}
// Precondition: caller must hold o.mu for reading or writing.
-func (o *OrderedChildren) checkExistingLocked(name string, child *Dentry) error {
+func (o *OrderedChildren) checkExistingLocked(name string, child Inode) error {
s, ok := o.set[name]
if !ok {
return syserror.ENOENT
}
- if s.Dentry != child {
- panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child))
+ if s.inode != child {
+ panic(fmt.Sprintf("Inode doesn't match what kernfs thinks! OrderedChild: %+v, kernfs: %+v", s.inode, child))
}
return nil
}
// Unlink implements Inode.Unlink.
-func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry) error {
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child Inode) error {
if !o.writable {
return syserror.EPERM
}
@@ -494,8 +521,8 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry
return nil
}
-// Rmdir implements Inode.Rmdir.
-func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *Dentry) error {
+// RmDir implements Inode.RmDir.
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) error {
// We're not responsible for checking that child is a directory, that it's
// empty, or updating any link counts; so this is the same as unlink.
return o.Unlink(ctx, name, child)
@@ -517,13 +544,13 @@ func (renameAcrossDifferentImplementationsError) Error() string {
// that will support Rename.
//
// Postcondition: reference on any replaced dentry transferred to caller.
-func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (*Dentry, error) {
- dst, ok := dstDir.inode.(interface{}).(*OrderedChildren)
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error {
+ dst, ok := dstDir.(interface{}).(*OrderedChildren)
if !ok {
- return nil, renameAcrossDifferentImplementationsError{}
+ return renameAcrossDifferentImplementationsError{}
}
if !o.writable || !dst.writable {
- return nil, syserror.EPERM
+ return syserror.EPERM
}
// Note: There's a potential deadlock below if concurrent calls to Rename
// refer to the same src and dst directories in reverse. We avoid any
@@ -536,12 +563,12 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c
defer dst.mu.Unlock()
}
if err := o.checkExistingLocked(oldname, child); err != nil {
- return nil, err
+ return err
}
// TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
- replaced := dst.replaceChildLocked(newname, child)
- return replaced, nil
+ dst.replaceChildLocked(ctx, newname, child)
+ return nil
}
// nthLocked returns an iterator to the nth child tracked by this object. The
@@ -576,11 +603,12 @@ func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry,
//
// +stateify savable
type StaticDirectory struct {
+ InodeAlwaysValid
InodeAttrs
InodeDirectoryNoNewChildren
- InodeNoDynamicLookup
InodeNoStatFS
InodeNotSymlink
+ InodeTemporary
OrderedChildren
StaticDirectoryRefs
@@ -591,19 +619,16 @@ type StaticDirectory struct {
var _ Inode = (*StaticDirectory)(nil)
// NewStaticDir creates a new static directory and returns its dentry.
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry {
+func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
inode.EnableLeakCheck()
- dentry := &Dentry{}
- dentry.Init(inode)
-
inode.OrderedChildren.Init(OrderedChildrenOptions{})
- links := inode.OrderedChildren.Populate(dentry, children)
+ links := inode.OrderedChildren.Populate(children)
inode.IncLinks(links)
- return dentry
+ return inode
}
// Init initializes StaticDirectory.
@@ -615,7 +640,7 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3
s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
-// Open implements kernfs.Inode.Open.
+// Open implements Inode.Open.
func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd, err := NewGenericDirectoryFD(rp.Mount(), d, &s.OrderedChildren, &s.locks, &opts, s.fdOpts)
if err != nil {
@@ -624,26 +649,36 @@ func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *De
return fd.VFSFileDescription(), nil
}
-// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
-// DecRef implements kernfs.Inode.DecRef.
-func (s *StaticDirectory) DecRef(context.Context) {
- s.StaticDirectoryRefs.DecRef(s.Destroy)
+// DecRef implements Inode.DecRef.
+func (s *StaticDirectory) DecRef(ctx context.Context) {
+ s.StaticDirectoryRefs.DecRef(func() { s.Destroy(ctx) })
}
-// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+// InodeAlwaysValid partially implements Inode.
//
// +stateify savable
-type AlwaysValid struct{}
+type InodeAlwaysValid struct{}
-// Valid implements kernfs.inodeDynamicLookup.Valid.
-func (*AlwaysValid) Valid(context.Context) bool {
+// Valid implements Inode.Valid.
+func (*InodeAlwaysValid) Valid(context.Context) bool {
return true
}
+// InodeTemporary partially implements Inode.
+//
+// +stateify savable
+type InodeTemporary struct{}
+
+// Keep implements Inode.Keep.
+func (*InodeTemporary) Keep() bool {
+ return false
+}
+
// InodeNoStatFS partially implements the Inode interface, where the client
// filesystem doesn't support statfs(2).
//
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 6d3d79333..606081e68 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -29,12 +29,16 @@
//
// Reference Model:
//
-// Kernfs dentries represents named pointers to inodes. Dentries and inodes have
+// Kernfs dentries represents named pointers to inodes. Kernfs is solely
+// reponsible for maintaining and modifying its dentry tree; inode
+// implementations can not access the tree. Dentries and inodes have
// independent lifetimes and reference counts. A child dentry unconditionally
// holds a reference on its parent directory's dentry. A dentry also holds a
-// reference on the inode it points to. Multiple dentries can point to the same
-// inode (for example, in the case of hardlinks). File descriptors hold a
-// reference to the dentry they're opened on.
+// reference on the inode it points to (although that might not be the only
+// reference on the inode). Due to this inodes can outlive the dentries that
+// point to them. Multiple dentries can point to the same inode (for example,
+// in the case of hardlinks). File descriptors hold a reference to the dentry
+// they're opened on.
//
// Dentries are guaranteed to exist while holding Filesystem.mu for
// reading. Dropping dentries require holding Filesystem.mu for writing. To
@@ -47,8 +51,8 @@
// kernfs.Dentry.dirMu
// vfs.VirtualFilesystem.mountMu
// vfs.Dentry.mu
-// kernfs.Filesystem.droppedDentriesMu
// (inode implementation locks, if any)
+// kernfs.Filesystem.droppedDentriesMu
package kernfs
import (
@@ -60,7 +64,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
@@ -95,7 +98,7 @@ type Filesystem struct {
// example:
//
// fs.mu.RLock()
- // fs.mu.processDeferredDecRefs()
+ // defer fs.processDeferredDecRefs()
// defer fs.mu.RUnlock()
// ...
// fs.deferDecRef(dentry)
@@ -108,8 +111,7 @@ type Filesystem struct {
// deferDecRef defers dropping a dentry ref until the next call to
// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
-//
-// Precondition: d must not already be pending destruction.
+// This may be called while Filesystem.mu or Dentry.dirMu is locked.
func (fs *Filesystem) deferDecRef(d *Dentry) {
fs.droppedDentriesMu.Lock()
fs.droppedDentries = append(fs.droppedDentries, d)
@@ -118,17 +120,14 @@ func (fs *Filesystem) deferDecRef(d *Dentry) {
// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
// droppedDentries list. See comment on Filesystem.mu.
+//
+// Precondition: Filesystem.mu or Dentry.dirMu must NOT be locked.
func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) {
- fs.mu.Lock()
- fs.processDeferredDecRefsLocked(ctx)
- fs.mu.Unlock()
-}
-
-// Precondition: fs.mu must be held for writing.
-func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) {
fs.droppedDentriesMu.Lock()
for _, d := range fs.droppedDentries {
- d.DecRef(ctx)
+ // Defer the DecRef call so that we are not holding droppedDentriesMu
+ // when DecRef is called.
+ defer d.DecRef(ctx)
}
fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
fs.droppedDentriesMu.Unlock()
@@ -157,17 +156,19 @@ const (
//
// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
// named reference to an inode. A dentry generally lives as long as it's part of
-// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
-// to them are removed. Dentries hold a single reference to the inode they point
+// a mounted filesystem tree. Kernfs drops dentries once all references to them
+// are dropped. Dentries hold a single reference to the inode they point
// to, and child dentries hold a reference on their parent.
//
// Must be initialized by Init prior to first use.
//
// +stateify savable
type Dentry struct {
+ vfsd vfs.Dentry
DentryRefs
- vfsd vfs.Dentry
+ // fs is the owning filesystem. fs is immutable.
+ fs *Filesystem
// flags caches useful information about the dentry from the inode. See the
// dflags* consts above. Must be accessed by atomic ops.
@@ -192,8 +193,9 @@ type Dentry struct {
// Precondition: Caller must hold a reference on inode.
//
// Postcondition: Caller's reference on inode is transferred to the dentry.
-func (d *Dentry) Init(inode Inode) {
+func (d *Dentry) Init(fs *Filesystem, inode Inode) {
d.vfsd.Init(d)
+ d.fs = fs
d.inode = inode
ftype := inode.Mode().FileType()
if ftype == linux.ModeDirectory {
@@ -222,14 +224,28 @@ func (d *Dentry) isSymlink() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *Dentry) DecRef(ctx context.Context) {
- // Before the destructor is called, Dentry must be removed from VFS' dentry cache.
+ decRefParent := false
+ d.fs.mu.Lock()
d.DentryRefs.DecRef(func() {
d.inode.DecRef(ctx) // IncRef from Init.
d.inode = nil
if d.parent != nil {
- d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild.
+ // We will DecRef d.parent once all locks are dropped.
+ decRefParent = true
+ d.parent.dirMu.Lock()
+ // Remove d from parent.children. It might already have been
+ // removed due to invalidation.
+ if _, ok := d.parent.children[d.name]; ok {
+ delete(d.parent.children, d.name)
+ d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
+ }
+ d.parent.dirMu.Unlock()
}
})
+ d.fs.mu.Unlock()
+ if decRefParent {
+ d.parent.DecRef(ctx) // IncRef from Dentry.insertChild.
+ }
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -247,26 +263,26 @@ func (d *Dentry) Watches() *vfs.Watches {
// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
func (d *Dentry) OnZeroWatches(context.Context) {}
-// InsertChild inserts child into the vfs dentry cache with the given name under
+// insertChild inserts child into the vfs dentry cache with the given name under
// this dentry. This does not update the directory inode, so calling this on its
// own isn't sufficient to insert a child into a directory.
//
// Precondition: d must represent a directory inode.
-func (d *Dentry) InsertChild(name string, child *Dentry) {
+func (d *Dentry) insertChild(name string, child *Dentry) {
d.dirMu.Lock()
- d.InsertChildLocked(name, child)
+ d.insertChildLocked(name, child)
d.dirMu.Unlock()
}
-// InsertChildLocked is equivalent to InsertChild, with additional
+// insertChildLocked is equivalent to insertChild, with additional
// preconditions.
//
// Preconditions:
// * d must represent a directory inode.
// * d.dirMu must be locked.
-func (d *Dentry) InsertChildLocked(name string, child *Dentry) {
+func (d *Dentry) insertChildLocked(name string, child *Dentry) {
if !d.isDir() {
- panic(fmt.Sprintf("InsertChildLocked called on non-directory Dentry: %+v.", d))
+ panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d))
}
d.IncRef() // DecRef in child's Dentry.destroy.
child.parent = d
@@ -277,36 +293,6 @@ func (d *Dentry) InsertChildLocked(name string, child *Dentry) {
d.children[name] = child
}
-// RemoveChild removes child from the vfs dentry cache. This does not update the
-// directory inode or modify the inode to be unlinked. So calling this on its own
-// isn't sufficient to remove a child from a directory.
-//
-// Precondition: d must represent a directory inode.
-func (d *Dentry) RemoveChild(name string, child *Dentry) error {
- d.dirMu.Lock()
- defer d.dirMu.Unlock()
- return d.RemoveChildLocked(name, child)
-}
-
-// RemoveChildLocked is equivalent to RemoveChild, with additional
-// preconditions.
-//
-// Precondition: d.dirMu must be locked.
-func (d *Dentry) RemoveChildLocked(name string, child *Dentry) error {
- if !d.isDir() {
- panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d))
- }
- c, ok := d.children[name]
- if !ok {
- return syserror.ENOENT
- }
- if c != child {
- panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child))
- }
- delete(d.children, name)
- return nil
-}
-
// Inode returns the dentry's inode.
func (d *Dentry) Inode() Inode {
return d.inode
@@ -348,11 +334,6 @@ type Inode interface {
// a blanket implementation for all non-directory inodes.
inodeDirectory
- // Method for inodes that represent dynamic directories and their
- // children. InodeNoDynamicLookup provides a blanket implementation for all
- // non-dynamic-directory inodes.
- inodeDynamicLookup
-
// Open creates a file description for the filesystem object represented by
// this inode. The returned file description should hold a reference on the
// dentry for its lifetime.
@@ -365,6 +346,14 @@ type Inode interface {
// corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem
// doesn't support statfs(2), this should return ENOSYS.
StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error)
+
+ // Keep indicates whether the dentry created after Inode.Lookup should be
+ // kept in the kernfs dentry tree.
+ Keep() bool
+
+ // Valid should return true if this inode is still valid, or needs to
+ // be resolved again by a call to Lookup.
+ Valid(ctx context.Context) bool
}
type inodeRefs interface {
@@ -397,8 +386,8 @@ type inodeMetadata interface {
// Precondition: All methods in this interface may only be called on directory
// inodes.
type inodeDirectory interface {
- // The New{File,Dir,Node,Symlink} methods below should return a new inode
- // hashed into this inode.
+ // The New{File,Dir,Node,Link,Symlink} methods below should return a new inode
+ // that will be hashed into the dentry tree.
//
// These inode constructors are inode-level operations rather than
// filesystem-level operations to allow client filesystems to mix different
@@ -409,60 +398,54 @@ type inodeDirectory interface {
HasChildren() bool
// NewFile creates a new regular file inode.
- NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error)
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (Inode, error)
// NewDir creates a new directory inode.
- NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error)
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (Inode, error)
// NewLink creates a new hardlink to a specified inode in this
// directory. Implementations should create a new kernfs Dentry pointing to
// target, and update target's link count.
- NewLink(ctx context.Context, name string, target Inode) (*Dentry, error)
+ NewLink(ctx context.Context, name string, target Inode) (Inode, error)
// NewSymlink creates a new symbolic link inode.
- NewSymlink(ctx context.Context, name, target string) (*Dentry, error)
+ NewSymlink(ctx context.Context, name, target string) (Inode, error)
// NewNode creates a new filesystem node for a mknod syscall.
- NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error)
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (Inode, error)
// Unlink removes a child dentry from this directory inode.
- Unlink(ctx context.Context, name string, child *Dentry) error
+ Unlink(ctx context.Context, name string, child Inode) error
// RmDir removes an empty child directory from this directory
// inode. Implementations must update the parent directory's link count,
// if required. Implementations are not responsible for checking that child
// is a directory, checking for an empty directory.
- RmDir(ctx context.Context, name string, child *Dentry) error
+ RmDir(ctx context.Context, name string, child Inode) error
// Rename is called on the source directory containing an inode being
// renamed. child should point to the resolved child in the source
- // directory. If Rename replaces a dentry in the destination directory, it
- // should return the replaced dentry or nil otherwise.
+ // directory.
//
// Precondition: Caller must serialize concurrent calls to Rename.
- Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (replaced *Dentry, err error)
-}
+ Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error
-type inodeDynamicLookup interface {
- // Lookup should return an appropriate dentry if name should resolve to a
- // child of this dynamic directory inode. This gives the directory an
- // opportunity on every lookup to resolve additional entries that aren't
- // hashed into the directory. This is only called when the inode is a
- // directory. If the inode is not a directory, or if the directory only
- // contains a static set of children, the implementer can unconditionally
- // return an appropriate error (ENOTDIR and ENOENT respectively).
+ // Lookup should return an appropriate inode if name should resolve to a
+ // child of this directory inode. This gives the directory an opportunity
+ // on every lookup to resolve additional entries. This is only called when
+ // the inode is a directory.
//
- // The child returned by Lookup will be hashed into the VFS dentry tree. Its
- // lifetime can be controlled by the filesystem implementation with an
- // appropriate implementation of Valid.
+ // The child returned by Lookup will be hashed into the VFS dentry tree,
+ // atleast for the duration of the current FS operation.
//
- // Lookup returns the child with an extra reference and the caller owns this
- // reference.
- Lookup(ctx context.Context, name string) (*Dentry, error)
-
- // Valid should return true if this inode is still valid, or needs to
- // be resolved again by a call to Lookup.
- Valid(ctx context.Context) bool
+ // Lookup must return the child with an extra reference whose ownership is
+ // transferred to the dentry that is created to point to that inode. If
+ // Inode.Keep returns false, that new dentry will be dropped at the end of
+ // the current filesystem operation (before returning back to the VFS
+ // layer) if no other ref is picked on that dentry. If Inode.Keep returns
+ // true, then the dentry will be cached into the dentry tree until it is
+ // Unlink'd or RmDir'd.
+ Lookup(ctx context.Context, name string) (Inode, error)
// IterDirents is used to iterate over dynamically created entries. It invokes
// cb on each entry in the directory represented by the Inode.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index e413242dc..82fa19c03 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file."
// RootDentryFn is a generator function for creating the root dentry of a test
// filesystem. See newTestSystem.
-type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
+type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode
// newTestSystem sets up a minimal environment for running a test, including an
// instance of a test filesystem. Tests can control the contents of the
@@ -72,14 +72,11 @@ type file struct {
content string
}
-func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
+func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode {
f := &file{}
f.content = content
f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
-
- d := &kernfs.Dentry{}
- d.Init(f)
- return d
+ return f
}
func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error {
@@ -98,27 +95,23 @@ func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.S
type readonlyDir struct {
readonlyDirRefs
attrs
+ kernfs.InodeAlwaysValid
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNoStatFS
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
locks vfs.FileLocks
-
- dentry kernfs.Dentry
}
-func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &readonlyDir{}
dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
dir.EnableLeakCheck()
- dir.dentry.Init(dir)
-
- dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
-
- return &dir.dentry
+ dir.IncLinks(dir.OrderedChildren.Populate(contents))
+ return dir
}
func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
@@ -131,35 +124,33 @@ func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernf
return fd.VFSFileDescription(), nil
}
-func (d *readonlyDir) DecRef(context.Context) {
- d.readonlyDirRefs.DecRef(d.Destroy)
+func (d *readonlyDir) DecRef(ctx context.Context) {
+ d.readonlyDirRefs.DecRef(func() { d.Destroy(ctx) })
}
type dir struct {
dirRefs
attrs
- kernfs.InodeNoDynamicLookup
+ kernfs.InodeAlwaysValid
kernfs.InodeNotSymlink
- kernfs.OrderedChildren
kernfs.InodeNoStatFS
+ kernfs.InodeTemporary
+ kernfs.OrderedChildren
locks vfs.FileLocks
- fs *filesystem
- dentry kernfs.Dentry
+ fs *filesystem
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &dir{}
dir.fs = fs
dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
dir.EnableLeakCheck()
- dir.dentry.Init(dir)
- dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
-
- return &dir.dentry
+ dir.IncLinks(dir.OrderedChildren.Populate(contents))
+ return dir
}
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
@@ -172,11 +163,11 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry
return fd.VFSFileDescription(), nil
}
-func (d *dir) DecRef(context.Context) {
- d.dirRefs.DecRef(d.Destroy)
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
-func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
dir := d.fs.newDir(creds, opts.Mode, nil)
if err := d.OrderedChildren.Insert(name, dir); err != nil {
@@ -187,7 +178,7 @@ func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*
return dir, nil
}
-func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
f := d.fs.newFile(creds, "")
if err := d.OrderedChildren.Insert(name, f); err != nil {
@@ -197,15 +188,15 @@ func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*
return f, nil
}
-func (*dir) NewLink(context.Context, string, kernfs.Inode) (*kernfs.Dentry, error) {
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
-func (*dir) NewSymlink(context.Context, string, string) (*kernfs.Dentry, error) {
+func (*dir) NewSymlink(context.Context, string, string) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
-func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*kernfs.Dentry, error) {
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
@@ -213,18 +204,22 @@ func (fsType) Name() string {
return "kernfs"
}
+func (fsType) Release(ctx context.Context) {}
+
func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
fs.VFSFilesystem().Init(vfsObj, &fst, fs)
root := fst.rootFn(creds, fs)
- return fs.VFSFilesystem(), root.VFSDentry(), nil
+ var d kernfs.Dentry
+ d.Init(&fs.Filesystem, root)
+ return fs.VFSFilesystem(), d.VFSDentry(), nil
}
// -------------------- Remainder of the file are test cases --------------------
func TestBasic(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"file1": fs.newFile(creds, staticFileContent),
})
})
@@ -233,8 +228,8 @@ func TestBasic(t *testing.T) {
}
func TestMkdirGetDentry(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"dir1": fs.newDir(creds, 0755, nil),
})
})
@@ -248,8 +243,8 @@ func TestMkdirGetDentry(t *testing.T) {
}
func TestReadStaticFile(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"file1": fs.newFile(creds, staticFileContent),
})
})
@@ -274,8 +269,8 @@ func TestReadStaticFile(t *testing.T) {
}
func TestCreateNewFileInStaticDir(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"dir1": fs.newDir(creds, 0755, nil),
})
})
@@ -301,7 +296,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
func TestDirFDReadWrite(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
return fs.newReadonlyDir(creds, 0755, nil)
})
defer sys.Destroy()
@@ -325,11 +320,11 @@ func TestDirFDReadWrite(t *testing.T) {
}
func TestDirFDIterDirents(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
// Fill root with nodes backed by various inode implementations.
"dir1": fs.newReadonlyDir(creds, 0755, nil),
- "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{
"dir3": fs.newDir(creds, 0755, nil),
}),
"file1": fs.newFile(creds, staticFileContent),
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 58a93eaac..934cc6c9e 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -38,13 +38,10 @@ type StaticSymlink struct {
var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) *Dentry {
+func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
inode := &StaticSymlink{}
inode.Init(creds, devMajor, devMinor, ino, target)
-
- d := &Dentry{}
- d.Init(inode)
- return d
+ return inode
}
// Init initializes the instance.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index ea7f073eb..d0ed17b18 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -29,24 +29,22 @@ import (
//
// +stateify savable
type syntheticDirectory struct {
+ InodeAlwaysValid
InodeAttrs
InodeNoStatFS
- InodeNoopRefCount
- InodeNoDynamicLookup
InodeNotSymlink
OrderedChildren
+ syntheticDirectoryRefs
locks vfs.FileLocks
}
var _ Inode = (*syntheticDirectory)(nil)
-func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *Dentry {
+func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode {
inode := &syntheticDirectory{}
inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
- d := &Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
@@ -69,34 +67,46 @@ func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath,
}
// NewFile implements Inode.NewFile.
-func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (Inode, error) {
if !opts.ForSyntheticMountpoint {
return nil, syserror.EPERM
}
- subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
- if err := dir.OrderedChildren.Insert(name, subdird); err != nil {
- subdird.DecRef(ctx)
+ subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ if err := dir.OrderedChildren.Insert(name, subdirI); err != nil {
+ subdirI.DecRef(ctx)
return nil, err
}
- return subdird, nil
+ return subdirI, nil
}
// NewLink implements Inode.NewLink.
-func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) {
+func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (Inode, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*Dentry, error) {
+func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (Inode, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (Inode, error) {
return nil, syserror.EPERM
}
+
+// DecRef implements Inode.DecRef.
+func (dir *syntheticDirectory) DecRef(ctx context.Context) {
+ dir.syntheticDirectoryRefs.DecRef(func() { dir.Destroy(ctx) })
+}
+
+// Keep implements Inode.Keep. This is redundant because inodes will never be
+// created via Lookup and inodes are always valid. Makes sense to return true
+// because these inodes are not temporary and should only be removed on RmDir.
+func (dir *syntheticDirectory) Keep() bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index dfbccd05f..e5f506d2e 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -60,6 +60,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
// FilesystemType.GetFilesystem.
//
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 4e2da4810..e44b79b68 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -39,6 +39,9 @@ func (filesystemType) Name() string {
return "pipefs"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
panic("pipefs.filesystemType.GetFilesystem should never be called")
@@ -165,7 +168,7 @@ func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vf
fs := mnt.Filesystem().Impl().(*filesystem)
inode := newInode(ctx, fs)
var d kernfs.Dentry
- d.Init(inode)
+ d.Init(&fs.Filesystem, inode)
defer d.DecRef(ctx)
return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
}
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 05d7948ea..fd70a07de 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -34,13 +34,14 @@ const Name = "proc"
// +stateify savable
type FilesystemType struct{}
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// Name implements vfs.FilesystemType.Name.
func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -73,7 +74,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
cgroups = data.Cgroups
}
- _, dentry := procfs.newTasksInode(k, pidns, cgroups)
+ inode := procfs.newTasksInode(k, pidns, cgroups)
+ var dentry kernfs.Dentry
+ dentry.Init(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
@@ -94,12 +97,9 @@ type dynamicInode interface {
Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
}
-func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
+ return inode
}
// +stateify savable
@@ -114,8 +114,8 @@ func newStaticFile(data string) *staticFile {
return &staticFile{StaticData: vfs.StaticData{Data: data}}
}
-func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
- return kernfs.NewStaticDir(creds, devMajor, devMinor, ino, perm, children, kernfs.GenericDirectoryFDOptions{
+func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
+ return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 47ecd941c..bad2fab4f 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -32,10 +32,11 @@ import (
// +stateify savable
type subtasksInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
subtasksInodeRefs
@@ -49,7 +50,7 @@ type subtasksInode struct {
var _ kernfs.Inode = (*subtasksInode)(nil)
-func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *kernfs.Dentry {
+func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode {
subInode := &subtasksInode{
fs: fs,
task: task,
@@ -62,14 +63,11 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
subInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: subInode, owner: task}
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
-
- return dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
tid, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -82,10 +80,10 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry
if subTask.ThreadGroup() != i.task.ThreadGroup() {
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers), nil
+ return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers)
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
@@ -186,6 +184,6 @@ func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credential
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *subtasksInode) DecRef(context.Context) {
- i.subtasksInodeRefs.DecRef(i.Destroy)
+func (i *subtasksInode) DecRef(ctx context.Context) {
+ i.subtasksInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 1f99183eb..b63a4eca0 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -35,8 +35,8 @@ type taskInode struct {
implStatFS
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
taskInodeRefs
@@ -47,40 +47,44 @@ type taskInode struct {
var _ kernfs.Inode = (*taskInode)(nil)
-func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
- // TODO(gvisor.dev/issue/164): Fail with ESRCH if task exited.
- contents := map[string]*kernfs.Dentry{
- "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}),
- "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+
+ contents := map[string]kernfs.Inode{
+ "auxv": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
"comm": fs.newComm(task, fs.NextIno(), 0444),
- "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "cwd": fs.newCwdSymlink(task, fs.NextIno()),
+ "environ": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
"exe": fs.newExeSymlink(task, fs.NextIno()),
"fd": fs.newFDDirInode(task),
"fdinfo": fs.newFDInfoDirInode(task),
- "gid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
- "io": fs.newTaskOwnedFile(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
- "maps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mapsData{task: task}),
- "mountinfo": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
- "mounts": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountsData{task: task}),
+ "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}),
"net": fs.newTaskNetDir(task),
- "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]kernfs.Inode{
"net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"),
"pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"),
"user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"),
}),
- "oom_score": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newStaticFile("0\n")),
- "oom_score_adj": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
- "smaps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &smapsData{task: task}),
- "stat": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
- "statm": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statmData{task: task}),
- "status": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
- "uid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ "oom_score": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &smapsData{task: task}),
+ "stat": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statmData{task: task}),
+ "status": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers)
}
if len(cgroupControllers) > 0 {
- contents["cgroup"] = fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ contents["cgroup"] = fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
}
taskInode := &taskInode{task: task}
@@ -89,17 +93,15 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := taskInode.OrderedChildren.Populate(dentry, contents)
+ links := taskInode.OrderedChildren.Populate(contents)
taskInode.IncLinks(links)
- return dentry
+ return inode, nil
}
-// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// Valid implements kernfs.Inode.Valid. This inode remains valid as long
// as the task is still running. When it's dead, another tasks with the same
// PID could replace it.
func (i *taskInode) Valid(ctx context.Context) bool {
@@ -123,8 +125,8 @@ func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, v
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *taskInode) DecRef(context.Context) {
- i.taskInodeRefs.DecRef(i.Destroy)
+func (i *taskInode) DecRef(ctx context.Context) {
+ i.taskInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
@@ -140,34 +142,23 @@ type taskOwnedInode struct {
var _ kernfs.Inode = (*taskOwnedInode)(nil)
-func (fs *filesystem) newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
- taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
+ return &taskOwnedInode{Inode: inode, owner: task}
}
-func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
- dir := &kernfs.StaticDirectory{}
-
+func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
- dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, kernfs.GenericDirectoryFDOptions{
- SeekEnd: kernfs.SeekEndZero,
- })
- dir.EnableLeakCheck()
-
- inode := &taskOwnedInode{Inode: dir, owner: task}
- d := &kernfs.Dentry{}
- d.Init(inode)
+ fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero}
+ dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
- dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := dir.OrderedChildren.Populate(d, children)
- dir.IncLinks(links)
+ return &taskOwnedInode{Inode: dir, owner: task}
+}
- return d
+func (i *taskOwnedInode) Valid(ctx context.Context) bool {
+ return i.owner.ExitState() != kernel.TaskExitDead && i.Inode.Valid(ctx)
}
// Stat implements kernfs.Inode.Stat.
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 0866cea2b..2c80ac5c2 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -63,7 +63,7 @@ type fdDir struct {
produceSymlink bool
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
@@ -109,16 +109,17 @@ type fdDirInode struct {
fdDir
fdDirInodeRefs
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
}
var _ kernfs.Inode = (*fdDirInode)(nil)
-func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
inode := &fdDirInode{
fdDir: fdDir{
fs: fs,
@@ -128,16 +129,17 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
}
inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
-
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ return inode
+}
- return dentry
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
+func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *fdDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -183,8 +185,8 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *fdDirInode) DecRef(context.Context) {
- i.fdDirInodeRefs.DecRef(i.Destroy)
+func (i *fdDirInode) DecRef(ctx context.Context) {
+ i.fdDirInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file.
@@ -202,16 +204,13 @@ type fdSymlink struct {
var _ kernfs.Inode = (*fdSymlink)(nil)
-func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kernfs.Inode {
inode := &fdSymlink{
task: task,
fd: fd,
}
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
@@ -236,6 +235,11 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen
return vd, "", nil
}
+// Valid implements kernfs.Inode.Valid.
+func (s *fdSymlink) Valid(ctx context.Context) bool {
+ return taskFDExists(ctx, s.task, s.fd)
+}
+
// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
//
// +stateify savable
@@ -243,16 +247,17 @@ type fdInfoDirInode struct {
fdDir
fdInfoDirInodeRefs
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
}
var _ kernfs.Inode = (*fdInfoDirInode)(nil)
-func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
inode := &fdInfoDirInode{
fdDir: fdDir{
fs: fs,
@@ -261,16 +266,12 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
}
inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
-
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
-
- return dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -283,7 +284,12 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentr
task: i.task,
fd: fd,
}
- return i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data), nil
+ return i.fs.newTaskOwnedInode(i.task, i.fs.NextIno(), 0444, data), nil
+}
+
+// IterDirents implements Inode.IterDirents.
+func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
}
// Open implements kernfs.Inode.Open.
@@ -298,8 +304,8 @@ func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *ker
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *fdInfoDirInode) DecRef(context.Context) {
- i.fdInfoDirInodeRefs.DecRef(i.Destroy)
+func (i *fdInfoDirInode) DecRef(ctx context.Context) {
+ i.fdInfoDirInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
@@ -328,3 +334,8 @@ func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "flags:\t0%o\n", flags)
return nil
}
+
+// Valid implements kernfs.Inode.Valid.
+func (d *fdInfoData) Valid(ctx context.Context) bool {
+ return taskFDExists(ctx, d.task, d.fd)
+}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index b81c8279e..79f8b7e9f 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -247,13 +247,10 @@ type commInode struct {
task *kernel.Task
}
-func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
inode := &commInode{task: task}
inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
@@ -658,29 +655,30 @@ type exeSymlink struct {
var _ kernfs.Inode = (*exeSymlink)(nil)
-func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &exeSymlink{task: task}
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
// Readlink implements kernfs.Inode.Readlink.
func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
- if !kernel.ContextCanTrace(ctx, s.task, false) {
- return "", syserror.EACCES
- }
-
- // Pull out the executable for /proc/[pid]/exe.
- exec, err := s.executable()
+ exec, _, err := s.Getlink(ctx, nil)
if err != nil {
return "", err
}
defer exec.DecRef(ctx)
- return exec.PathnameWithDeleted(ctx), nil
+ root := vfs.RootFromContext(ctx)
+ if !root.Ok() {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ vfsObj := exec.Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, exec)
+ return name, nil
}
// Getlink implements kernfs.Inode.Getlink.
@@ -688,23 +686,12 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent
if !kernel.ContextCanTrace(ctx, s.task, false) {
return vfs.VirtualDentry{}, "", syserror.EACCES
}
-
- exec, err := s.executable()
- if err != nil {
- return vfs.VirtualDentry{}, "", err
- }
- defer exec.DecRef(ctx)
-
- vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
- vd.IncRef()
- return vd, "", nil
-}
-
-func (s *exeSymlink) executable() (file fsbridge.File, err error) {
if err := checkTaskState(s.task); err != nil {
- return nil, err
+ return vfs.VirtualDentry{}, "", err
}
+ var err error
+ var exec fsbridge.File
s.task.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
@@ -715,12 +702,75 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) {
// The MemoryManager may be destroyed, in which case
// MemoryManager.destroy will simply set the executable to nil
// (with locks held).
- file = mm.Executable()
- if file == nil {
+ exec = mm.Executable()
+ if exec == nil {
err = syserror.ESRCH
}
})
- return
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef(ctx)
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+// cwdSymlink is an symlink for the /proc/[pid]/cwd file.
+//
+// +stateify savable
+type cwdSymlink struct {
+ implStatFS
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*cwdSymlink)(nil)
+
+func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
+ inode := &cwdSymlink{task: task}
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ return inode
+}
+
+// Readlink implements kernfs.Inode.Readlink.
+func (s *cwdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
+ cwd, _, err := s.Getlink(ctx, nil)
+ if err != nil {
+ return "", err
+ }
+ defer cwd.DecRef(ctx)
+
+ root := vfs.RootFromContext(ctx)
+ if !root.Ok() {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ vfsObj := cwd.Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, cwd)
+ return name, nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *cwdSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ cwd := s.task.FSContext().WorkingDirectoryVFS2()
+ if !cwd.Ok() {
+ // It could have raced with process deletion.
+ return vfs.VirtualDentry{}, "", syserror.ESRCH
+ }
+ return cwd, "", nil
}
// mountInfoData is used to implement /proc/[pid]/mountinfo.
@@ -792,7 +842,7 @@ type namespaceSymlink struct {
task *kernel.Task
}
-func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) kernfs.Inode {
// Namespace symlinks should contain the namespace name and the inode number
// for the namespace instance, so for example user:[123456]. We currently fake
// the inode number by sticking the symlink inode in its place.
@@ -803,9 +853,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
+ return taskInode
}
// Readlink implements kernfs.Inode.Readlink.
@@ -823,11 +871,12 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
}
// Create a synthetic inode to represent the namespace.
+ fs := mnt.Filesystem().Impl().(*filesystem)
dentry := &kernfs.Dentry{}
- dentry.Init(&namespaceInode{})
+ dentry.Init(&fs.Filesystem, &namespaceInode{})
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
- vd.IncRef()
- dentry.DecRef(ctx)
+ // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1.
+ mnt.IncRef()
return vd, "", nil
}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index e7f748655..3425e8698 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -37,12 +37,12 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode {
k := task.Kernel()
pidns := task.PIDNamespace()
root := auth.NewRootCredentials(pidns.UserNamespace())
- var contents map[string]*kernfs.Dentry
+ var contents map[string]kernfs.Inode
if stack := task.NetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
@@ -56,34 +56,34 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
// TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
// network namespace.
- contents = map[string]*kernfs.Dentry{
- "dev": fs.newDentry(root, fs.NextIno(), 0444, &netDevData{stack: stack}),
- "snmp": fs.newDentry(root, fs.NextIno(), 0444, &netSnmpData{stack: stack}),
+ contents = map[string]kernfs.Inode{
+ "dev": fs.newInode(root, 0444, &netDevData{stack: stack}),
+ "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, if the file contains a header the stub is just the header
// otherwise it is an empty file.
- "arp": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(arp)),
- "netlink": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(netlink)),
- "netstat": fs.newDentry(root, fs.NextIno(), 0444, &netStatData{}),
- "packet": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(packet)),
- "protocols": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(protocols)),
+ "arp": fs.newInode(root, 0444, newStaticFile(arp)),
+ "netlink": fs.newInode(root, 0444, newStaticFile(netlink)),
+ "netstat": fs.newInode(root, 0444, &netStatData{}),
+ "packet": fs.newInode(root, 0444, newStaticFile(packet)),
+ "protocols": fs.newInode(root, 0444, newStaticFile(protocols)),
// Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
// high res timer ticks per sec (ClockGetres returns 1ns resolution).
- "psched": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(psched)),
- "ptype": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(ptype)),
- "route": fs.newDentry(root, fs.NextIno(), 0444, &netRouteData{stack: stack}),
- "tcp": fs.newDentry(root, fs.NextIno(), 0444, &netTCPData{kernel: k}),
- "udp": fs.newDentry(root, fs.NextIno(), 0444, &netUDPData{kernel: k}),
- "unix": fs.newDentry(root, fs.NextIno(), 0444, &netUnixData{kernel: k}),
+ "psched": fs.newInode(root, 0444, newStaticFile(psched)),
+ "ptype": fs.newInode(root, 0444, newStaticFile(ptype)),
+ "route": fs.newInode(root, 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}),
+ "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}),
+ "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}),
}
if stack.SupportsIPv6() {
- contents["if_inet6"] = fs.newDentry(root, fs.NextIno(), 0444, &ifinet6{stack: stack})
- contents["ipv6_route"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(""))
- contents["tcp6"] = fs.newDentry(root, fs.NextIno(), 0444, &netTCP6Data{kernel: k})
- contents["udp6"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(upd6))
+ contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6))
}
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index d8f5dd509..3259c3732 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -38,10 +38,11 @@ const (
// +stateify savable
type tasksInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
kernfs.OrderedChildren
tasksInodeRefs
@@ -52,8 +53,6 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- selfSymlink *kernfs.Dentry
- threadSelfSymlink *kernfs.Dentry
// cgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
@@ -63,52 +62,53 @@ type tasksInode struct {
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
+func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
- contents := map[string]*kernfs.Dentry{
- "cpuinfo": fs.newDentry(root, fs.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
- "filesystems": fs.newDentry(root, fs.NextIno(), 0444, &filesystemsData{}),
- "loadavg": fs.newDentry(root, fs.NextIno(), 0444, &loadavgData{}),
+ contents := map[string]kernfs.Inode{
+ "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newInode(root, 0444, &filesystemsData{}),
+ "loadavg": fs.newInode(root, 0444, &loadavgData{}),
"sys": fs.newSysDir(root, k),
- "meminfo": fs.newDentry(root, fs.NextIno(), 0444, &meminfoData{}),
+ "meminfo": fs.newInode(root, 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
"net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
- "stat": fs.newDentry(root, fs.NextIno(), 0444, &statData{}),
- "uptime": fs.newDentry(root, fs.NextIno(), 0444, &uptimeData{}),
- "version": fs.newDentry(root, fs.NextIno(), 0444, &versionData{}),
+ "stat": fs.newInode(root, 0444, &statData{}),
+ "uptime": fs.newInode(root, 0444, &uptimeData{}),
+ "version": fs.newInode(root, 0444, &versionData{}),
}
inode := &tasksInode{
pidns: pidns,
fs: fs,
- selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns),
- threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns),
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
-
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := inode.OrderedChildren.Populate(dentry, contents)
+ links := inode.OrderedChildren.Populate(contents)
inode.IncLinks(links)
- return inode, dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
- // Try to lookup a corresponding task.
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
+ // Check if a static entry was looked up.
+ if d, err := i.OrderedChildren.Lookup(ctx, name); err == nil {
+ return d, nil
+ }
+
+ // Not a static entry. Try to lookup a corresponding task.
tid, err := strconv.ParseUint(name, 10, 64)
if err != nil {
+ root := auth.NewRootCredentials(i.pidns.UserNamespace())
// If it failed to parse, check if it's one of the special handled files.
switch name {
case selfName:
- return i.selfSymlink, nil
+ return i.newSelfSymlink(root), nil
case threadSelfName:
- return i.threadSelfSymlink, nil
+ return i.newThreadSelfSymlink(root), nil
}
return nil, syserror.ENOENT
}
@@ -118,10 +118,10 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, e
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers), nil
+ return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers)
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
@@ -229,8 +229,8 @@ func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.St
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *tasksInode) DecRef(context.Context) {
- i.tasksInodeRefs.DecRef(i.Destroy)
+func (i *tasksInode) DecRef(ctx context.Context) {
+ i.tasksInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// staticFileSetStat implements a special static file that allows inode
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index f268c59b0..07c27cdd9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -43,13 +43,10 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
- inode := &selfSymlink{pidns: pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+ inode := &selfSymlink{pidns: i.pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ return inode
}
func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
@@ -87,13 +84,10 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
- inode := &threadSelfSymlink{pidns: pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+ inode := &threadSelfSymlink{pidns: i.pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ return inode
}
func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 3312b0418..95420368d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -40,93 +40,93 @@ const (
)
// newSysDir returns the dentry corresponding to /proc/sys directory.
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
- return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "kernel": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}),
- "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)),
+func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ return fs.newStaticDir(root, map[string]kernfs.Inode{
+ "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "hostname": fs.newInode(root, 0444, &hostnameData{}),
+ "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)),
}),
- "vm": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}),
- "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "vm": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")),
}),
"net": fs.newSysNetDir(root, k),
})
}
// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
-func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
- var contents map[string]*kernfs.Dentry
+func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ var contents map[string]kernfs.Inode
// TODO(gvisor.dev/issue/1833): Support for using the network stack in the
// network namespace of the calling process.
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
- contents = map[string]*kernfs.Dentry{
- "ipv4": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}),
+ contents = map[string]kernfs.Inode{
+ "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("16000 65535")),
- "ip_local_reserved_ports": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "ipfrag_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("30")),
- "ip_nonlocal_bind": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "ip_no_pmtu_disc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")),
+ "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")),
// tcp_allowed_congestion_control tell the user what they are able to
// do as an unprivledged process so we leave it empty.
- "tcp_allowed_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "tcp_available_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
- "tcp_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
// Many of the following stub files are features netstack doesn't
// support. The unsupported features return "0" to indicate they are
// disabled.
- "tcp_base_mss": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1280")),
- "tcp_dsack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_early_retrans": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fastopen": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fastopen_key": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "tcp_invalid_ratelimit": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_intvl": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_probes": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("7200")),
- "tcp_mtu_probing": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_no_metrics_save": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_probe_interval": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_probe_threshold": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_retries1": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
- "tcp_retries2": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("15")),
- "tcp_rfc1337": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_slow_start_after_idle": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_synack_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
- "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
- "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")),
}),
- "core": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")),
- "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")),
- "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
- "optmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "rmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "rmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "somaxconn": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("128")),
- "wmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "wmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "core": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newInode(root, 0444, newStaticFile("10")),
+ "message_cost": fs.newInode(root, 0444, newStaticFile("5")),
+ "optmem_max": fs.newInode(root, 0444, newStaticFile("0")),
+ "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
+ "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
+ "somaxconn": fs.newInode(root, 0444, newStaticFile("128")),
+ "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
+ "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
}),
}
}
- return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents)
+ return fs.newStaticDir(root, contents)
}
// mmapMinAddrData implements vfs.DynamicBytesSource for
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index f693f9060..2582ababd 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -67,6 +67,7 @@ var (
taskStaticFiles = map[string]testutil.DirentType{
"auxv": linux.DT_REG,
"cgroup": linux.DT_REG,
+ "cwd": linux.DT_LNK,
"cmdline": linux.DT_REG,
"comm": linux.DT_REG,
"environ": linux.DT_REG,
@@ -108,9 +109,12 @@ func setup(t *testing.T) *testutil.System {
if err != nil {
t.Fatalf("NewMountNamespace(): %v", err)
}
+ root := mntns.Root()
+ root.IncRef()
+ defer root.DecRef(ctx)
pop := &vfs.PathOperation{
- Root: mntns.Root(),
- Start: mntns.Root(),
+ Root: root,
+ Start: root,
Path: fspath.Parse("/proc"),
}
if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
@@ -118,8 +122,8 @@ func setup(t *testing.T) *testutil.System {
}
pop = &vfs.PathOperation{
- Root: mntns.Root(),
- Start: mntns.Root(),
+ Root: root,
+ Start: root,
Path: fspath.Parse("/proc"),
}
mntOpts := &vfs.MountOptions{
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
index 067c1657f..adb610213 100644
--- a/pkg/sentry/fsimpl/signalfd/BUILD
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/kernel",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index bf11b425a..10f1452ef 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -16,7 +16,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -95,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -105,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index 29e5371d6..cf91ea36c 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -46,6 +46,9 @@ func (filesystemType) Name() string {
return "sockfs"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -114,6 +117,6 @@ func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
d := &kernfs.Dentry{}
- d.Init(i)
+ d.Init(&fs.Filesystem, i)
return d.VFSDentry()
}
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index b75d70ae6..94366d429 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -27,12 +27,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) *kernfs.Dentry {
+func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
k := &kcovInode{}
k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
- d := &kernfs.Dentry{}
- d.Init(k)
- return d
+ return k
}
// kcovInode implements kernfs.Inode.
@@ -104,7 +102,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
func (fd *kcovFD) Release(ctx context.Context) {
// kcov instances have reference counts in Linux, but this seems sufficient
// for our purposes.
- fd.kcov.Reset()
+ fd.kcov.Clear()
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1568c581f..1ad679830 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -52,6 +52,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
@@ -64,15 +67,15 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
- root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"block": fs.newDir(creds, defaultSysDirMode, nil),
"bus": fs.newDir(creds, defaultSysDirMode, nil),
- "class": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"power_supply": fs.newDir(creds, defaultSysDirMode, nil),
}),
"dev": fs.newDir(creds, defaultSysDirMode, nil),
- "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
- "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"cpu": cpuDir(ctx, fs, creds),
}),
}),
@@ -82,13 +85,15 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
"module": fs.newDir(creds, defaultSysDirMode, nil),
"power": fs.newDir(creds, defaultSysDirMode, nil),
})
- return fs.VFSFilesystem(), root.VFSDentry(), nil
+ var rootD kernfs.Dentry
+ rootD.Init(&fs.Filesystem, root)
+ return fs.VFSFilesystem(), rootD.VFSDentry(), nil
}
-func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry {
+func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
- children := map[string]*kernfs.Dentry{
+ children := map[string]kernfs.Inode{
"online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
"possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
"present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
@@ -99,14 +104,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernf
return fs.newDir(creds, defaultSysDirMode, children)
}
-func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry {
+func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
// If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs
// should be mounted at debug/, but for our purposes, it is sufficient to
// keep it in sys.
- var children map[string]*kernfs.Dentry
+ var children map[string]kernfs.Inode
if coverage.KcovAvailable() {
- children = map[string]*kernfs.Dentry{
- "debug": fs.newDir(creds, linux.FileMode(0700), map[string]*kernfs.Dentry{
+ children = map[string]kernfs.Inode{
+ "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{
"kcov": fs.newKcovFile(ctx, creds),
}),
}
@@ -125,27 +130,23 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type dir struct {
dirRefs
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeTemporary
kernfs.OrderedChildren
locks vfs.FileLocks
-
- dentry kernfs.Dentry
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
d := &dir{}
d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
d.EnableLeakCheck()
- d.dentry.Init(d)
-
- d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents))
-
- return &d.dentry
+ d.IncLinks(d.OrderedChildren.Populate(contents))
+ return d
}
// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
@@ -165,8 +166,8 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry
}
// DecRef implements kernfs.Inode.DecRef.
-func (d *dir) DecRef(context.Context) {
- d.dirRefs.DecRef(d.Destroy)
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
// StatFS implements kernfs.Inode.StatFS.
@@ -190,12 +191,10 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
c := &cpuFile{maxCores: maxCores}
c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
- d := &kernfs.Dentry{}
- d.Init(c)
- return d
+ return c
}
// +stateify savable
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 568132121..1a8525b06 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -46,16 +46,18 @@ type System struct {
// NewSystem constructs a System.
//
-// Precondition: Caller must hold a reference on MntNs, whose ownership
+// Precondition: Caller must hold a reference on mns, whose ownership
// is transferred to the new System.
func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
+ root := mns.Root()
+ root.IncRef()
s := &System{
t: t,
Ctx: ctx,
Creds: auth.CredentialsFromContext(ctx),
VFS: v,
MntNs: mns,
- Root: mns.Root(),
+ Root: root,
}
return s
}
@@ -254,10 +256,10 @@ func (d *DirentCollector) Contains(name string, typ uint8) error {
defer d.mu.Unlock()
dirent, ok := d.dirents[name]
if !ok {
- return fmt.Errorf("No dirent named %q found", name)
+ return fmt.Errorf("no dirent named %q found", name)
}
if dirent.Type != typ {
- return fmt.Errorf("Dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
+ return fmt.Errorf("dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
}
return nil
}
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 5209a17af..3cc63e732 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -193,6 +193,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
// Create nested directories with given depth.
root := mntns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
vd := root
vd.IncRef()
@@ -387,6 +388,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
// Create the mount point.
root := mntns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
pop := vfs.PathOperation{
Root: root,
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index be29a2363..2f856ce36 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -165,6 +165,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create the pipe.
root := mntns.Root()
+ root.IncRef()
pop := vfs.PathOperation{
Root: root,
Start: root,
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index a199eb33d..ce4e3eda7 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -293,7 +293,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.
optional.End = pgend
}
- cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ cerr := rf.data.Fill(ctx, required, optional, rf.size, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
// Newly-allocated pages are zeroed, so we don't need to do anything.
return dsts.NumBytes(), nil
})
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index cefec8fde..e2a0aac69 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -74,6 +74,8 @@ type filesystem struct {
mu sync.RWMutex `state:"nosave"`
nextInoMinusOne uint64 // accessed using atomic memory operations
+
+ root *dentry
}
// Name implements vfs.FilesystemType.Name.
@@ -81,6 +83,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// FilesystemOpts is used to pass configuration data to tmpfs.
//
// +stateify savable
@@ -194,6 +199,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.vfsfs.DecRef(ctx)
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
}
+ fs.root = root
return &fs.vfsfs, &root.vfsd, nil
}
@@ -205,6 +211,37 @@ func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *au
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.mu.Lock()
+ if fs.root.inode.isDir() {
+ fs.root.releaseChildrenLocked(ctx)
+ }
+ fs.mu.Unlock()
+}
+
+// releaseChildrenLocked is called on the mount point by filesystem.Release() to
+// destroy all objects in the mount. It performs a depth-first walk of the
+// filesystem and "unlinks" everything by decrementing link counts
+// appropriately. There should be no open file descriptors when this is called,
+// so each inode should only have one outstanding reference that is removed once
+// its link count hits zero.
+//
+// Note that we do not update filesystem state precisely while tearing down (for
+// instance, the child maps are ignored)--we only care to remove all remaining
+// references so that every filesystem object gets destroyed. Also note that we
+// do not need to trigger DecRef on the mount point itself or any child mount;
+// these are taken care of by the destructor of the enclosing MountNamespace.
+//
+// Precondition: filesystem.mu is held.
+func (d *dentry) releaseChildrenLocked(ctx context.Context) {
+ dir := d.inode.impl.(*directory)
+ for _, child := range dir.childMap {
+ if child.inode.isDir() {
+ child.releaseChildrenLocked(ctx)
+ child.inode.decLinksLocked(ctx) // link for child/.
+ dir.inode.decLinksLocked(ctx) // link for child/..
+ }
+ child.inode.decLinksLocked(ctx) // link for child
+ }
}
// immutable
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
index 99c8e3c0f..fc5323abc 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -46,6 +46,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
}
root := mntns.Root()
+ root.IncRef()
return vfsObj, root, func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index bc8e38431..0ca750281 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
licenses(["notice"])
@@ -26,3 +26,22 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "verity_test",
+ srcs = [
+ "verity_test.go",
+ ],
+ library = ":verity",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 26b117ca4..3b3c8725f 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -20,6 +20,7 @@ import (
"io"
"strconv"
"strings"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -155,11 +156,10 @@ afterSymlink:
return child, nil
}
-// verifyChild verifies the root hash of child against the already verified
-// root hash of the parent to ensure the child is expected. verifyChild
-// triggers a sentry panic if unexpected modifications to the file system are
-// detected. In noCrashOnVerificationFailure mode it returns a syserror
-// instead.
+// verifyChild verifies the hash of child against the already verified hash of
+// the parent to ensure the child is expected. verifyChild triggers a sentry
+// panic if unexpected modifications to the file system are detected. In
+// noCrashOnVerificationFailure mode it returns a syserror instead.
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
// TODO(b/166474175): Investigate all possible errors returned in this
// function, and make sure we differentiate all errors that indicate unexpected
@@ -174,12 +174,12 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, err
}
- verityMu.RLock()
- defer verityMu.RUnlock()
+ fs.verityMu.RLock()
+ defer fs.verityMu.RUnlock()
// Read the offset of the child from the extended attributes of the
// corresponding Merkle tree file.
- // This is the offset of the root hash for child in its parent's Merkle
- // tree file.
+ // This is the offset of the hash for child in its parent's Merkle tree
+ // file.
off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{
Root: child.lowerMerkleVD,
Start: child.lowerMerkleVD,
@@ -204,7 +204,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
- // Open parent Merkle tree file to read and verify child's root hash.
+ // Open parent Merkle tree file to read and verify child's hash.
parentMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: parent.lowerMerkleVD,
Start: parent.lowerMerkleVD,
@@ -223,7 +223,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// dataSize is the size of raw data for the Merkle tree. For a file,
// dataSize is the size of the whole file. For a directory, dataSize is
- // the size of all its children's root hashes.
+ // the size of all its children's hashes.
dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
Size: sizeOfStringInt32,
@@ -251,32 +251,152 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Ctx: ctx,
}
+ parentStat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parent.lowerVD,
+ Start: parent.lowerVD,
+ }, &vfs.StatOptions{})
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return nil, err
+ }
+
// Since we are verifying against a directory Merkle tree, buf should
- // contain the root hash of the children in the parent Merkle tree when
+ // contain the hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
+ if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ ReadOffset: int64(offset),
+ ReadSize: int64(merkletree.DigestSize()),
+ Expected: parent.hash,
+ DataAndTreeInSameFile: true,
+ }); err != nil && err != io.EOF {
return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
- // Cache child root hash when it's verified the first time.
- if len(child.rootHash) == 0 {
- child.rootHash = buf.Bytes()
+ // Cache child hash when it's verified the first time.
+ if len(child.hash) == 0 {
+ child.hash = buf.Bytes()
}
return child, nil
}
+// verifyStat verifies the stat against the verified hash. The mode/uid/gid of
+// the file is cached after verified.
+func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+
+ // Get the path to the child dentry. This is only used to provide path
+ // information in failure case.
+ childPath, err := vfsObj.PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD)
+ if err != nil {
+ return err
+ }
+
+ fs.verityMu.RLock()
+ defer fs.verityMu.RUnlock()
+
+ fd, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err == syserror.ENOENT {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ if err == syserror.ENODATA {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ size, err := strconv.Atoi(merkleSize)
+ if err != nil {
+ return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ }
+
+ fdReader := vfs.FileReadWriteSeeker{
+ FD: fd,
+ Ctx: ctx,
+ }
+
+ var buf bytes.Buffer
+ params := &merkletree.VerifyParams{
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ ReadOffset: 0,
+ // Set read size to 0 so only the metadata is verified.
+ ReadSize: 0,
+ Expected: d.hash,
+ DataAndTreeInSameFile: false,
+ }
+ if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
+ params.DataAndTreeInSameFile = true
+ }
+
+ if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
+ return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ }
+ d.mode = uint32(stat.Mode)
+ d.uid = stat.UID
+ d.gid = stat.GID
+ return nil
+}
+
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
// If enabling verification on files/directories is not allowed
// during runtime, all cached children are already verified. If
// runtime enable is allowed and the parent directory is
- // enabled, we should verify the child root hash here because
- // it may be cached before enabled.
- if fs.allowRuntimeEnable && len(parent.rootHash) != 0 {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
- return nil, err
+ // enabled, we should verify the child hash here because it may
+ // be cached before enabled.
+ if fs.allowRuntimeEnable {
+ if isEnabled(parent) {
+ if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ return nil, err
+ }
+ }
+ if isEnabled(child) {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.lowerVD,
+ Start: child.lowerVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ return nil, err
+ }
}
}
return child, nil
@@ -360,9 +480,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// file open and ready to use.
// This may cause empty and unused Merkle tree files in
// allowRuntimeEnable mode, if they are never enabled. This
- // does not affect verification, as we rely on cached root hash
- // to decide whether to perform verification, not the existence
- // of the Merkle tree file. Also, those Merkle tree files are
+ // does not affect verification, as we rely on cached hash to
+ // decide whether to perform verification, not the existence of
+ // the Merkle tree file. Also, those Merkle tree files are
// always hidden and cannot be accessed by verity fs users.
if fs.allowRuntimeEnable {
childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
@@ -426,20 +546,25 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
child.parent = parent
child.name = name
- // TODO(b/162788573): Verify child metadata.
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
- // Verify child root hash. This should always be performed unless in
+ // Verify child hash. This should always be performed unless in
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
- if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) {
+ if isEnabled(parent) {
if _, err := fs.verifyChild(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
+ if isEnabled(child) {
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ child.destroyLocked(ctx)
+ return nil, err
+ }
+ }
return child, nil
}
@@ -693,22 +818,24 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// be called if a verity FD is created successfully.
defer merkleWriter.DecRef(ctx)
- parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.parent.lowerMerkleVD,
- Start: d.parent.lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_WRONLY | linux.O_APPEND,
- })
- if err != nil {
- if err == syserror.ENOENT {
- parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ if d.parent != nil {
+ parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.parent.lowerMerkleVD,
+ Start: d.parent.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ }
+ return nil, err
}
- return nil, err
+ // parentMerkleWriter is cleaned up if any error occurs. IncRef
+ // will be called if a verity FD is created successfully.
+ defer parentMerkleWriter.DecRef(ctx)
}
- // parentMerkleWriter is cleaned up if any error occurs. IncRef
- // will be called if a verity FD is created successfully.
- defer parentMerkleWriter.DecRef(ctx)
}
fd := &fileDescription{
@@ -769,6 +896,8 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
// StatAt implements vfs.FilesystemImpl.StatAt.
+// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should
+// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
@@ -786,6 +915,11 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(d) {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 3129f290d..70034280b 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -49,12 +49,12 @@ const Name = "verity"
const merklePrefix = ".merkle.verity."
// merkleoffsetInParentXattr is the extended attribute name specifying the
-// offset of child root hash in its parent's Merkle tree.
+// offset of child hash in its parent's Merkle tree.
const merkleOffsetInParentXattr = "user.merkle.offset"
// merkleSizeXattr is the extended attribute name specifying the size of data
// hashed by the corresponding Merkle tree. For a file, it's the size of the
-// whole file. For a directory, it's the size of all its children's root hashes.
+// whole file. For a directory, it's the size of all its children's hashes.
const merkleSizeXattr = "user.merkle.size"
// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
@@ -68,11 +68,6 @@ const sizeOfStringInt32 = 10
// flag.
var noCrashOnVerificationFailure bool
-// verityMu synchronizes enabling verity files, protects files or directories
-// from being enabled by different threads simultaneously. It also ensures that
-// verity does not access files that are being enabled.
-var verityMu sync.RWMutex
-
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
@@ -106,6 +101,17 @@ type filesystem struct {
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
renameMu sync.RWMutex `state:"nosave"`
+
+ // verityMu synchronizes enabling verity files, protects files or
+ // directories from being enabled by different threads simultaneously.
+ // It also ensures that verity does not access files that are being
+ // enabled.
+ //
+ // Also, the directory Merkle trees depends on the generated trees of
+ // its children. So they shouldn't be enabled the same time. This lock
+ // is for the whole file system to ensure that no more than one file is
+ // enabled the same time.
+ verityMu sync.RWMutex
}
// InternalFilesystemOptions may be passed as
@@ -142,6 +148,17 @@ func (FilesystemType) Name() string {
return Name
}
+// isEnabled checks whether the target is enabled with verity features. It
+// should always be true if runtime enable is not allowed. In runtime enable
+// mode, it returns true if the target has been enabled with
+// ioctl(FS_IOC_ENABLE_VERITY).
+func isEnabled(d *dentry) bool {
+ return !d.fs.allowRuntimeEnable || len(d.hash) != 0
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
@@ -245,13 +262,18 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
- // TODO(b/162788573): Verify Metadata.
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
+ d.hash = make([]byte, len(iopts.RootHash))
- d.rootHash = make([]byte, len(iopts.RootHash))
- copy(d.rootHash, iopts.RootHash)
+ if !fs.allowRuntimeEnable {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return nil, nil, err
+ }
+ }
+
+ copy(d.hash, iopts.RootHash)
d.vfsd.Init(d)
fs.rootDentry = d
@@ -303,8 +325,8 @@ type dentry struct {
// in the underlying file system.
lowerMerkleVD vfs.VirtualDentry
- // rootHash is the rootHash for the current file or directory.
- rootHash []byte
+ // hash is the calculated hash for the current file or directory.
+ hash []byte
}
// newDentry creates a new dentry representing the given verity file. The
@@ -488,6 +510,11 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(fd.d) {
+ if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
@@ -498,11 +525,11 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// generateMerkle generates a Merkle tree file for fd. If fd points to a file
-// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root
-// hash of the generated Merkle tree and the data size is returned.
-// If fd points to a regular file, the data is the content of the file. If fd
-// points to a directory, the data is all root hahes of its children, written
-// to the Merkle tree file.
+// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
+// of the generated Merkle tree and the data size is returned. If fd points to
+// a regular file, the data is the content of the file. If fd points to a
+// directory, the data is all hahes of its children, written to the Merkle tree
+// file.
func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
fdReader := vfs.FileReadWriteSeeker{
FD: fd.lowerFD,
@@ -516,8 +543,10 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
FD: fd.merkleWriter,
Ctx: ctx,
}
- var rootHash []byte
- var dataSize uint64
+ params := &merkletree.GenerateParams{
+ TreeReader: &merkleReader,
+ TreeWriter: &merkleWriter,
+ }
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
case linux.S_IFREG:
@@ -528,75 +557,90 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
if err != nil {
return nil, 0, err
}
- dataSize = stat.Size
- rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */)
- if err != nil {
- return nil, 0, err
- }
+ params.File = &fdReader
+ params.Size = int64(stat.Size)
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = false
case linux.S_IFDIR:
- // For a directory, generate a Merkle tree based on the root
- // hashes of its children that has already been written to the
- // Merkle tree file.
+ // For a directory, generate a Merkle tree based on the hashes
+ // of its children that has already been written to the Merkle
+ // tree file.
merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{})
if err != nil {
return nil, 0, err
}
- dataSize = merkleStat.Size
- rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */)
+ params.Size = int64(merkleStat.Size)
+
+ stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{})
if err != nil {
return nil, 0, err
}
+
+ params.File = &merkleReader
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = true
default:
// TODO(b/167728857): Investigate whether and how we should
// enable other types of file.
return nil, 0, syserror.EINVAL
}
- return rootHash, dataSize, nil
+ hash, err := merkletree.Generate(params)
+ return hash, uint64(params.Size), err
}
// enableVerity enables verity features on fd by generating a Merkle tree file
-// and stores its root hash in its parent directory's Merkle tree.
-func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+// and stores its hash in its parent directory's Merkle tree.
+func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) {
if !fd.d.fs.allowRuntimeEnable {
return 0, syserror.EPERM
}
- // Lock to prevent other threads performing enable or access the file
- // while it's being enabled.
- verityMu.Lock()
- defer verityMu.Unlock()
+ fd.d.fs.verityMu.Lock()
+ defer fd.d.fs.verityMu.Unlock()
- if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil {
+ // In allowRuntimeEnable mode, the underlying fd and read/write fd for
+ // the Merkle tree file should have all been initialized. For any file
+ // or directory other than the root, the parent Merkle tree file should
+ // have also been initialized.
+ if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
}
- rootHash, dataSize, err := fd.generateMerkle(ctx)
+ hash, dataSize, err := fd.generateMerkle(ctx)
if err != nil {
return 0, err
}
- stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- return 0, err
- }
+ if fd.parentMerkleWriter != nil {
+ stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
- // Write the root hash of fd to the parent directory's Merkle tree
- // file, as it should be part of the parent Merkle tree data.
- // parentMerkleWriter is open with O_APPEND, so it should write
- // directly to the end of the file.
- if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil {
- return 0, err
- }
+ // Write the hash of fd to the parent directory's Merkle tree
+ // file, as it should be part of the parent Merkle tree data.
+ // parentMerkleWriter is open with O_APPEND, so it should write
+ // directly to the end of the file.
+ if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(hash), vfs.WriteOptions{}); err != nil {
+ return 0, err
+ }
- // Record the offset of the root hash of fd in parent directory's
- // Merkle tree file.
- if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
- Name: merkleOffsetInParentXattr,
- Value: strconv.Itoa(int(stat.Size)),
- }); err != nil {
- return 0, err
+ // Record the offset of the hash of fd in parent directory's
+ // Merkle tree file.
+ if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
+ Name: merkleOffsetInParentXattr,
+ Value: strconv.Itoa(int(stat.Size)),
+ }); err != nil {
+ return 0, err
+ }
}
// Record the size of the data being hashed for fd.
@@ -606,22 +650,59 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
}); err != nil {
return 0, err
}
- fd.d.rootHash = append(fd.d.rootHash, rootHash...)
+ fd.d.hash = append(fd.d.hash, hash...)
return 0, nil
}
-func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+// measureVerity returns the hash of fd, saved in verityDigest.
+func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ var metadata linux.DigestMetadata
+
+ // If allowRuntimeEnable is true, an empty fd.d.hash indicates that
+ // verity is not enabled for the file. If allowRuntimeEnable is false,
+ // this is an integrity violation because all files should have verity
+ // enabled, in which case fd.d.hash should be set.
+ if len(fd.d.hash) == 0 {
+ if fd.d.fs.allowRuntimeEnable {
+ return 0, syserror.ENODATA
+ }
+ return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found")
+ }
+
+ // The first part of VerityDigest is the metadata.
+ if _, err := metadata.CopyIn(t, verityDigest); err != nil {
+ return 0, err
+ }
+ if metadata.DigestSize < uint16(len(fd.d.hash)) {
+ return 0, syserror.EOVERFLOW
+ }
+
+ // Populate the output digest size, since DigestSize is both input and
+ // output.
+ metadata.DigestSize = uint16(len(fd.d.hash))
+
+ // First copy the metadata.
+ if _, err := metadata.CopyOut(t, verityDigest); err != nil {
+ return 0, err
+ }
+
+ // Now copy the root hash bytes to the memory after metadata.
+ _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash)
+ return 0, err
+}
+
+func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) {
f := int32(0)
- // All enabled files should store a root hash. This flag is not settable
- // via FS_IOC_SETFLAGS.
- if len(fd.d.rootHash) != 0 {
+ // All enabled files should store a hash. This flag is not settable via
+ // FS_IOC_SETFLAGS.
+ if len(fd.d.hash) != 0 {
f |= linux.FS_VERITY_FL
}
t := kernel.TaskFromContext(ctx)
- addr := args[2].Pointer()
- _, err := primitive.CopyInt32Out(t, addr, f)
+ _, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -629,11 +710,15 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar
func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FS_IOC_ENABLE_VERITY:
- return fd.enableVerity(ctx, uio, args)
+ return fd.enableVerity(ctx, uio)
+ case linux.FS_IOC_MEASURE_VERITY:
+ return fd.measureVerity(ctx, uio, args[2].Pointer())
case linux.FS_IOC_GETFLAGS:
- return fd.getFlags(ctx, uio, args)
+ return fd.verityFlags(ctx, uio, args[2].Pointer())
default:
- return fd.lowerFD.Ioctl(ctx, uio, args)
+ // TODO(b/169682228): Investigate which ioctl commands should
+ // be allowed.
+ return 0, syserror.ENOSYS
}
}
@@ -641,10 +726,12 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
// allowRuntimeEnable mode.
- if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 {
+ if !isEnabled(fd.d) {
return fd.lowerFD.PRead(ctx, dst, offset, opts)
}
+ fd.d.fs.verityMu.RLock()
+ defer fd.d.fs.verityMu.RUnlock()
// dataSize is the size of the whole file.
dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
@@ -678,9 +765,22 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Ctx: ctx,
}
- n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */)
+ n, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ ReadOffset: offset,
+ ReadSize: dst.NumBytes(),
+ Expected: fd.d.hash,
+ DataAndTreeInSameFile: false,
+ })
if err != nil {
- return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
new file mode 100644
index 000000000..e301d35f5
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -0,0 +1,491 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// rootMerkleFilename is the name of the root Merkle tree file.
+const rootMerkleFilename = "root.verity"
+
+// maxDataSize is the maximum data size written to the file for test.
+const maxDataSize = 100000
+
+// newVerityRoot creates a new verity mount, and returns the root. The
+// underlying file system is tmpfs. If the error is not nil, then cleanup
+// should be called when the root is no longer needed.
+func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+ rand.Seed(time.Now().UnixNano())
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: InternalFilesystemOptions{
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ AllowRuntimeEnable: true,
+ NoCrashOnVerificationFailure: true,
+ },
+ },
+ })
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ }
+ root := mntns.Root()
+ root.IncRef()
+ t.Helper()
+ t.Cleanup(func() {
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
+ })
+ return vfsObj, root, nil
+}
+
+// newFileFD creates a new file in the verity mount, and returns the FD. The FD
+// points to a file that has random data generated.
+func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+
+ // Create the file in the underlying file system.
+ lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Generate random data to be written to the file.
+ dataSize := rand.Intn(maxDataSize) + 1
+ data := make([]byte, dataSize)
+ rand.Read(data)
+
+ // Write directly to the underlying FD, since verity FD is read-only.
+ n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ if n != int64(len(data)) {
+ return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data))
+ }
+
+ lowerFD.DecRef(ctx)
+
+ // Now open the verity file descriptor.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular | mode,
+ })
+ return fd, dataSize, err
+}
+
+// corruptRandomBit randomly flips a bit in the file represented by fd.
+func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
+ // Flip a random bit in the underlying file.
+ randomPos := int64(rand.Intn(size))
+ byteToModify := make([]byte, 1)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PRead: %v", err)
+ }
+ byteToModify[0] ^= 1
+ if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PWrite: %v", err)
+ }
+ return nil
+}
+
+// TestOpen ensures that when a file is created, the corresponding Merkle tree
+// file and the root Merkle tree file exist.
+func TestOpen(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
+}
+
+// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file
+// succeeds after enabling verity for it.
+func TestReadUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
+}
+
+// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
+// succeeds after enabling verity for it.
+func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
+}
+
+// TestModifiedFileFails ensures that read from a modified verity file fails.
+func TestModifiedFileFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded with modified file")
+ }
+}
+
+// TestModifiedMerkleFails ensures that read from a verity file fails if the
+// corresponding Merkle tree file is modified.
+func TestModifiedMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
+}
+
+// TestModifiedParentMerkleFails ensures that open a verity enabled file in a
+// verity enabled directory fails if the hashes related to the target file in
+// the parent Merkle tree file is modified.
+func TestModifiedParentMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
+}
+
+// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
+// succeeds after enabling verity for it.
+func TestUnmodifiedStatSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
+}
+
+// TestModifiedStatFails checks that getting stat for a file with modified stat
+// should fail.
+func TestModifiedStatFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
+}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 61c78569d..300b7ccce 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -7,11 +7,14 @@ go_library(
srcs = [
"cgroup.go",
"hostmm.go",
+ "membarrier.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
"//pkg/fd",
"//pkg/log",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/hostmm/membarrier.go b/pkg/sentry/hostmm/membarrier.go
new file mode 100644
index 000000000..4468d75f1
--- /dev/null
+++ b/pkg/sentry/hostmm/membarrier.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostmm
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ haveMembarrierGlobal = false
+ haveMembarrierPrivateExpedited = false
+)
+
+func init() {
+ supported, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_QUERY, 0 /* flags */, 0 /* unused */)
+ if e != 0 {
+ if e != syscall.ENOSYS {
+ log.Warningf("membarrier(MEMBARRIER_CMD_QUERY) failed: %s", e.Error())
+ }
+ return
+ }
+ // We don't use MEMBARRIER_CMD_GLOBAL_EXPEDITED because this sends IPIs to
+ // all CPUs running tasks that have previously invoked
+ // MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, which presents a DOS risk.
+ // (MEMBARRIER_CMD_GLOBAL is synchronize_rcu(), i.e. it waits for an RCU
+ // grace period to elapse without bothering other CPUs.
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED sends IPIs only to CPUs running tasks
+ // sharing the caller's MM.)
+ if supported&linux.MEMBARRIER_CMD_GLOBAL != 0 {
+ haveMembarrierGlobal = true
+ }
+ if req := uintptr(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED); supported&req == req {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ log.Warningf("membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) failed: %s", e.Error())
+ } else {
+ haveMembarrierPrivateExpedited = true
+ }
+ }
+}
+
+// HaveGlobalMemoryBarrier returns true if GlobalMemoryBarrier is supported.
+func HaveGlobalMemoryBarrier() bool {
+ return haveMembarrierGlobal
+}
+
+// GlobalMemoryBarrier blocks until "all running threads [in the host OS] have
+// passed through a state where all memory accesses to user-space addresses
+// match program order between entry to and return from [GlobalMemoryBarrier]",
+// as for membarrier(2).
+//
+// Preconditions: HaveGlobalMemoryBarrier() == true.
+func GlobalMemoryBarrier() error {
+ if _, _, e := syscall.Syscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_GLOBAL, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
+
+// HaveProcessMemoryBarrier returns true if ProcessMemoryBarrier is supported.
+func HaveProcessMemoryBarrier() bool {
+ return haveMembarrierPrivateExpedited
+}
+
+// ProcessMemoryBarrier is equivalent to GlobalMemoryBarrier, but only
+// synchronizes with threads sharing a virtual address space (from the host OS'
+// perspective) with the calling thread.
+//
+// Preconditions: HaveProcessMemoryBarrier() == true.
+func ProcessMemoryBarrier() error {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a43c549f1..5de70aecb 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -69,8 +69,8 @@ go_template_instance(
prefix = "socket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*SocketEntry",
- "Linker": "*SocketEntry",
+ "Element": "*SocketRecordVFS1",
+ "Linker": "*SocketRecordVFS1",
},
)
@@ -204,7 +204,6 @@ go_library(
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/amutex",
- "//pkg/binary",
"//pkg/bits",
"//pkg/bpf",
"//pkg/context",
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index d3e76ca7b..060c056df 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -89,7 +89,7 @@ func (kcov *Kcov) TaskWork(t *Task) {
kcov.mu.Lock()
defer kcov.mu.Unlock()
- if kcov.mode != linux.KCOV_TRACE_PC {
+ if kcov.mode != linux.KCOV_MODE_TRACE_PC {
return
}
@@ -146,7 +146,7 @@ func (kcov *Kcov) InitTrace(size uint64) error {
}
// EnableTrace performs the KCOV_ENABLE_TRACE ioctl.
-func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
+func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error {
t := TaskFromContext(ctx)
if t == nil {
panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine")
@@ -160,9 +160,9 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
return syserror.EINVAL
}
- switch traceMode {
+ switch traceKind {
case linux.KCOV_TRACE_PC:
- kcov.mode = traceMode
+ kcov.mode = linux.KCOV_MODE_TRACE_PC
case linux.KCOV_TRACE_CMP:
// We do not support KCOV_MODE_TRACE_CMP.
return syserror.ENOTSUP
@@ -175,6 +175,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
}
kcov.owningTask = t
+ t.SetKcov(kcov)
t.RegisterWork(kcov)
// Clear existing coverage data; the task expects to read only coverage data
@@ -196,26 +197,35 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error {
if t != kcov.owningTask {
return syserror.EINVAL
}
- kcov.owningTask = nil
kcov.mode = linux.KCOV_MODE_INIT
- kcov.resetLocked()
+ kcov.owningTask = nil
+ kcov.mappable = nil
return nil
}
-// Reset is called when the owning task exits.
-func (kcov *Kcov) Reset() {
+// Clear resets the mode and clears the owning task and memory mapping for kcov.
+// It is called when the fd corresponding to kcov is closed. Note that the mode
+// needs to be set so that the next call to kcov.TaskWork() will exit early.
+func (kcov *Kcov) Clear() {
kcov.mu.Lock()
- kcov.resetLocked()
+ kcov.clearLocked()
kcov.mu.Unlock()
}
-// The kcov instance is reset when the owning task exits or when tracing is
-// disabled.
-func (kcov *Kcov) resetLocked() {
+func (kcov *Kcov) clearLocked() {
+ kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
- if kcov.mappable != nil {
- kcov.mappable = nil
- }
+ kcov.mappable = nil
+}
+
+// OnTaskExit is called when the owning task exits. It is similar to
+// kcov.Clear(), except the memory mapping is not cleared, so that the same
+// mapping can be used in the future if kcov is enabled again by another task.
+func (kcov *Kcov) OnTaskExit() {
+ kcov.mu.Lock()
+ kcov.mode = linux.KCOV_MODE_INIT
+ kcov.owningTask = nil
+ kcov.mu.Unlock()
}
// ConfigureMMap is called by the vfs.FileDescription for this kcov instance to
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 08bb5bd12..675506269 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -220,13 +220,18 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // sockets is the list of all network sockets the system. Protected by
- // extMu.
+ // sockets is the list of all network sockets in the system.
+ // Protected by extMu.
+ // TODO(gvisor.dev/issue/1624): Only used by VFS1.
sockets socketList
- // nextSocketEntry is the next entry number to use in sockets. Protected
+ // socketsVFS2 records all network sockets in the system. Protected by
+ // extMu.
+ socketsVFS2 map[*vfs.FileDescription]*SocketRecord
+
+ // nextSocketRecord is the next entry number to use in sockets. Protected
// by extMu.
- nextSocketEntry uint64
+ nextSocketRecord uint64
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
return fmt.Errorf("failed to create sockfs mount: %v", err)
}
k.socketMount = socketMount
+
+ k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
}
return nil
@@ -834,14 +841,16 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
if ctx.args.MountNamespaceVFS2 == nil {
return nil
}
- // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
- return ctx.args.MountNamespaceVFS2.Root()
+ root := ctx.args.MountNamespaceVFS2.Root()
+ root.IncRef()
+ return root
case vfs.CtxMountNamespace:
if ctx.k.globalInit == nil {
return nil
}
- // MountNamespaceVFS2 takes a reference for us.
- return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns.IncRef()
+ return mntns
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case inet.CtxStack:
@@ -897,14 +906,13 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
if VFS2Enabled {
mntnsVFS2 = args.MountNamespaceVFS2
if mntnsVFS2 == nil {
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
+ // Add a reference to the namespace, which is transferred to the new process.
mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
+ mntnsVFS2.IncRef()
}
// Get the root directory from the MountNamespace.
root := mntnsVFS2.Root()
- // The call to newFSContext below will take a reference on root, so we
- // don't need to hold this one.
+ root.IncRef()
defer root.DecRef(ctx)
// Grab the working directory.
@@ -1512,20 +1520,27 @@ func (k *Kernel) SupervisorContext() context.Context {
}
}
-// SocketEntry represents a socket recorded in Kernel.sockets. It implements
+// SocketRecord represents a socket recorded in Kernel.socketsVFS2.
+//
+// +stateify savable
+type SocketRecord struct {
+ k *Kernel
+ Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1.
+ SockVFS2 *vfs.FileDescription // Only used by VFS2.
+ ID uint64 // Socket table entry number.
+}
+
+// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
// +stateify savable
-type SocketEntry struct {
+type SocketRecordVFS1 struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- SockVFS2 *vfs.FileDescription
- ID uint64 // Socket table entry number.
+ SocketRecord
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketEntry) WeakRefGone(context.Context) {
+func (s *SocketRecordVFS1) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1536,9 +1551,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) {
// Precondition: Caller must hold a reference to sock.
func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{k: k, ID: id}
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecordVFS1{
+ SocketRecord: SocketRecord{
+ k: k,
+ ID: id,
+ },
+ }
s.Sock = refs.NewWeakRef(sock, s)
k.sockets.PushBack(s)
k.extMu.Unlock()
@@ -1550,29 +1570,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
// Precondition: Caller must hold a reference to sock.
//
// Note that the socket table will not hold a reference on the
-// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+// vfs.FileDescription.
func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{
+ if _, ok := k.socketsVFS2[sock]; ok {
+ panic(fmt.Sprintf("Socket %p added twice", sock))
+ }
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecord{
k: k,
ID: id,
SockVFS2: sock,
}
- k.sockets.PushBack(s)
+ k.socketsVFS2[sock] = s
+ k.extMu.Unlock()
+}
+
+// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table.
+func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ delete(k.socketsVFS2, sock)
k.extMu.Unlock()
}
// ListSockets returns a snapshot of all sockets.
//
-// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef()
// to get a reference on a socket in the table.
-func (k *Kernel) ListSockets() []*SocketEntry {
+func (k *Kernel) ListSockets() []*SocketRecord {
k.extMu.Lock()
- var socks []*SocketEntry
- for s := k.sockets.Front(); s != nil; s = s.Next() {
- socks = append(socks, s)
+ var socks []*SocketRecord
+ if VFS2Enabled {
+ for _, s := range k.socketsVFS2 {
+ socks = append(socks, s)
+ }
+ } else {
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, &s.SocketRecord)
+ }
}
k.extMu.Unlock()
return socks
@@ -1613,16 +1649,16 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
if ctx.k.globalInit == nil {
return vfs.VirtualDentry{}
}
- mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
- defer mntns.DecRef(ctx)
- // Root() takes a reference on the root dirent for us.
- return mntns.Root()
+ root := ctx.k.GlobalInit().Leader().MountNamespaceVFS2().Root()
+ root.IncRef()
+ return root
case vfs.CtxMountNamespace:
if ctx.k.globalInit == nil {
return nil
}
- // MountNamespaceVFS2() takes a reference for us.
- return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns.IncRef()
+ return mntns
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case inet.CtxStack:
@@ -1703,3 +1739,20 @@ func (k *Kernel) ShmMount() *vfs.Mount {
func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+
+// Release releases resources owned by k.
+//
+// Precondition: This should only be called after the kernel is fully
+// initialized, e.g. after k.Start() has been called.
+func (k *Kernel) Release() {
+ ctx := k.SupervisorContext()
+ if VFS2Enabled {
+ k.hostMount.DecRef(ctx)
+ k.pipeMount.DecRef(ctx)
+ k.shmMount.DecRef(ctx)
+ k.socketMount.DecRef(ctx)
+ k.vfs.Release(ctx)
+ }
+ k.timekeeper.Destroy()
+ k.vdso.Release(ctx)
+}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 449643118..99134e634 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/amutex",
"//pkg/buffer",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 6d58b682f..f665920cb 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -145,9 +146,14 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
v = math.MaxInt32 // Silently truncate.
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ iocc := primitive.IOCopyContext{
+ IO: io,
+ Ctx: ctx,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }
+ _, err := primitive.CopyInt32Out(&iocc, args[2].Pointer(), int32(v))
return 0, err
default:
return 0, syscall.ENOTTY
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index c38c5a40c..387edfa91 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -18,7 +18,6 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
@@ -27,25 +26,18 @@ import (
const maxSyscallFilterInstructions = 1 << 15
-// seccompData is equivalent to struct seccomp_data, which contains the data
-// passed to seccomp-bpf filters.
-type seccompData struct {
- // nr is the system call number.
- nr int32
-
- // arch is an AUDIT_ARCH_* value indicating the system call convention.
- arch uint32
-
- // instructionPointer is the value of the instruction pointer at the time
- // of the system call.
- instructionPointer uint64
-
- // args contains the first 6 system call arguments.
- args [6]uint64
-}
-
-func (d *seccompData) asBPFInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder}
+// dataAsBPFInput returns a serialized BPF program, only valid on the current task
+// goroutine.
+//
+// Note: this is called for every syscall, which is a very hot path.
+func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
+ buf := t.CopyScratchBuffer(d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ // Go-marshal always uses the native byte order.
+ Order: usermem.ByteOrder,
+ }
}
func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
@@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u
}
func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
- data := seccompData{
- nr: sysno,
- arch: t.tc.st.AuditNumber,
- instructionPointer: uint64(ip),
+ data := linux.SeccompData{
+ Nr: sysno,
+ Arch: t.tc.st.AuditNumber,
+ InstructionPointer: uint64(ip),
}
// data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so
// we can't do any slicing tricks or even use copy/append here.
for i, arg := range args {
- if i >= len(data.args) {
+ if i >= len(data.Args) {
break
}
- data.args[i] = arg.Uint64()
+ data.Args[i] = arg.Uint64()
}
- input := data.asBPFInput()
+ input := dataAsBPFInput(t, &data)
ret := uint32(linux.SECCOMP_RET_ALLOW)
f := t.syscallFilters.Load()
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 3eb78e91b..76d472292 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index b07e1c1bd..78f718cfe 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -17,7 +17,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
@@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a436610c9..e90a19cfb 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -735,7 +735,6 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
func (t *Task) IsChrooted() bool {
if VFS2Enabled {
realRoot := t.mountNamespaceVFS2.Root()
- defer realRoot.DecRef(t)
root := t.fsContext.RootDirectoryVFS2()
defer root.DecRef(t)
return root != realRoot
@@ -868,7 +867,6 @@ func (t *Task) MountNamespace() *fs.MountNamespace {
func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
t.mu.Lock()
defer t.mu.Unlock()
- t.mountNamespaceVFS2.IncRef()
return t.mountNamespaceVFS2
}
@@ -917,7 +915,7 @@ func (t *Task) SetKcov(k *Kcov) {
// ResetKcov clears the kcov instance associated with t.
func (t *Task) ResetKcov() {
if t.kcov != nil {
- t.kcov.Reset()
+ t.kcov.OnTaskExit()
t.kcov = nil
}
}
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 9fa528384..d1136461a 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -126,7 +126,11 @@ func (t *Task) SyscallTable() *SyscallTable {
// Preconditions: The caller must be running on the task goroutine, or t.mu
// must be locked.
func (t *Task) Stack() *arch.Stack {
- return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())}
+ return &arch.Stack{
+ Arch: t.Arch(),
+ IO: t.MemoryManager(),
+ Bottom: usermem.Addr(t.Arch().Stack()),
+ }
}
// LoadTaskImage loads a specified file into a new TaskContext.
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index feaa38596..ebdb83061 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -259,7 +259,11 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// Set up the signal handler. If we have a saved signal mask, the signal
// handler should run with the current mask, but sigreturn should restore
// the saved one.
- st := &arch.Stack{t.Arch(), mm, sp}
+ st := &arch.Stack{
+ Arch: t.Arch(),
+ IO: mm,
+ Bottom: sp,
+ }
mask := t.signalMask
if t.haveSavedSignalMask {
mask = t.savedSignalMask
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 5ae5906e8..fdadb52c0 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -265,6 +265,13 @@ func (ns *PIDNamespace) Tasks() []*Task {
return tasks
}
+// NumTasks returns the number of tasks in ns.
+func (ns *PIDNamespace) NumTasks() int {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return len(ns.tids)
+}
+
// ThreadGroups returns a snapshot of the thread groups in ns.
func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup {
return ns.ThreadGroupsAppend(nil)
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index e44a139b3..9bc452e67 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -17,7 +17,6 @@ package kernel
import (
"fmt"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -28,6 +27,8 @@ import (
//
// They are exposed to the VDSO via a parameter page managed by VDSOParamPage,
// which also includes a sequence counter.
+//
+// +marshal
type vdsoParams struct {
monotonicReady uint64
monotonicBaseCycles int64
@@ -68,6 +69,13 @@ type VDSOParamPage struct {
// checked in state_test_util tests, causing this field to change across
// save / restore.
seq uint64
+
+ // copyScratchBuffer is a temporary buffer used to marshal the params before
+ // copying it to the real parameter page. The parameter page is typically
+ // updated at a moderate frequency of ~O(seconds) throughout the lifetime of
+ // the sentry, so reusing this buffer is a good tradeoff between memory
+ // usage and the cost of allocation.
+ copyScratchBuffer []byte
}
// NewVDSOParamPage returns a VDSOParamPage.
@@ -79,7 +87,11 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
- return &VDSOParamPage{mfp: mfp, fr: fr}
+ return &VDSOParamPage{
+ mfp: mfp,
+ fr: fr,
+ copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()),
+ }
}
// access returns a mapping of the param page.
@@ -133,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error {
// Get the new params.
p := f()
- buf := binary.Marshal(nil, usermem.ByteOrder, p)
+ buf := v.copyScratchBuffer[:p.SizeBytes()]
+ p.MarshalUnsafe(buf)
// Skip the sequence counter.
if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 15c88aa7c..c69b62db9 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -122,7 +122,7 @@ func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch
if err != nil {
return nil, err
}
- return &arch.Stack{a, m, ar.End}, nil
+ return &arch.Stack{Arch: a, IO: m, Bottom: ar.End}, nil
}
const (
@@ -247,20 +247,20 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
}
// Push the original filename to the stack, for AT_EXECFN.
- execfn, err := stack.Push(args.Filename)
- if err != nil {
+ if _, err := stack.PushNullTerminatedByteSlice([]byte(args.Filename)); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
}
+ execfn := stack.Bottom
// Push 16 random bytes on the stack which AT_RANDOM will point to.
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux())
}
- random, err := stack.Push(b)
- if err != nil {
+ if _, err = stack.PushNullTerminatedByteSlice(b[:]); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux())
}
+ random := stack.Bottom
c := auth.CredentialsFromContext(ctx)
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 05a294fe6..241d87835 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -380,3 +380,9 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
return vdsoAddr, nil
}
+
+// Release drops references on mappings held by v.
+func (v *VDSO) Release(ctx context.Context) {
+ v.ParamPage.DecRef(ctx)
+ v.vdso.DecRef(ctx)
+}
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index a44fa2b95..7fd77925f 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -127,7 +127,7 @@ func (t Translation) FileRange() FileRange {
// Preconditions: Same as Mappable.Translate.
func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
// Verify that the inputs to Mappable.Translate were valid.
- if !required.WellFormed() || required.Length() <= 0 {
+ if !required.WellFormed() || required.Length() == 0 {
panic(fmt.Sprintf("invalid required range: %v", required))
}
if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
@@ -145,7 +145,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required)
}
for i, t := range ts {
- if !t.Source.WellFormed() || t.Source.Length() <= 0 {
+ if !t.Source.WellFormed() || t.Source.Length() == 0 {
return fmt.Errorf("Translation %+v has invalid Source", t)
}
if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 8c9f11cce..92cc87d84 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -235,6 +235,20 @@ type MemoryManager struct {
// vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
vdsoSigReturnAddr uint64
+
+ // membarrierPrivateEnabled is non-zero if EnableMembarrierPrivate has
+ // previously been called. Since, as of this writing,
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED is implemented as a global memory
+ // barrier, membarrierPrivateEnabled has no other effect.
+ //
+ // membarrierPrivateEnabled is accessed using atomic memory operations.
+ membarrierPrivateEnabled uint32
+
+ // membarrierRSeqEnabled is non-zero if EnableMembarrierRSeq has previously
+ // been called.
+ //
+ // membarrierRSeqEnabled is accessed using atomic memory operations.
+ membarrierRSeqEnabled uint32
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 30facebf7..7e5f7de64 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -36,7 +36,7 @@ import (
// * ar.Length() != 0.
func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -100,7 +100,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user
// (i.e. permission checks must have been performed against vmas).
func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -193,7 +193,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR
// getVecPMAsLocked; other clients should call one of those instead.
func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -223,7 +223,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
// Need a pma here.
optAR := vseg.Range().Intersect(pgap.Range())
if checkInvariants {
- if optAR.Length() <= 0 {
+ if optAR.Length() == 0 {
panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap))
}
}
@@ -560,7 +560,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat
// Invalidate implements memmap.MappingSpace.Invalidate.
func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -583,7 +583,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate
// * ar must be page-aligned.
func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -629,7 +629,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
// * ar must be page-aligned.
func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -715,10 +715,10 @@ func Unpin(prs []PinnedRange) {
// * oldAR and newAR must be page-aligned.
func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
if checkInvariants {
- if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() {
+ if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() {
panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
}
- if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() {
+ if !newAR.WellFormed() || newAR.Length() == 0 || !newAR.IsPageAligned() {
panic(fmt.Sprintf("invalid newAR: %v", newAR))
}
if oldAR.Length() > newAR.Length() {
@@ -778,7 +778,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
// into mm.pmas.
func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -831,7 +831,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe
// * pseg.Range().Contains(ar.Start).
func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -1050,7 +1050,7 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if !pseg.Ok() {
panic("terminal pma iterator")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().IsSupersetOf(ar) {
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index a2555ba1a..675efdc7c 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -17,6 +17,7 @@ package mm
import (
"fmt"
mrand "math/rand"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -1274,3 +1275,27 @@ func (mm *MemoryManager) VirtualDataSize() uint64 {
defer mm.mappingMu.RUnlock()
return mm.dataAS
}
+
+// EnableMembarrierPrivate causes future calls to IsMembarrierPrivateEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierPrivate() {
+ atomic.StoreUint32(&mm.membarrierPrivateEnabled, 1)
+}
+
+// IsMembarrierPrivateEnabled returns true if mm.EnableMembarrierPrivate() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierPrivateEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierPrivateEnabled) != 0
+}
+
+// EnableMembarrierRSeq causes future calls to IsMembarrierRSeqEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierRSeq() {
+ atomic.StoreUint32(&mm.membarrierRSeqEnabled, 1)
+}
+
+// IsMembarrierRSeqEnabled returns true if mm.EnableMembarrierRSeq() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierRSeqEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierRSeqEnabled) != 0
+}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index f769d8294..b8df72813 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -266,7 +266,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
// * ar.Length() != 0.
func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -350,7 +350,7 @@ const guardBytes = 256 * usermem.PageSize
// * ar must be page-aligned.
func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -371,7 +371,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange)
// * ar must be page-aligned.
func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -511,7 +511,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Range().IsSupersetOf(ar) {
@@ -536,7 +536,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !mr.WellFormed() || mr.Length() <= 0 {
+ if !mr.WellFormed() || mr.Length() == 0 {
panic(fmt.Sprintf("invalid mr: %v", mr))
}
if !vseg.mappableRange().IsSupersetOf(mr) {
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 209b28053..db7d55ef2 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/context",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/hostmm",
"//pkg/sentry/memmap",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 3970dd81d..dd2bbeb12 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -9,12 +9,12 @@ go_library(
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
- "bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
"bluepill_arm64.go",
"bluepill_arm64.s",
"bluepill_arm64_unsafe.go",
"bluepill_fault.go",
+ "bluepill_impl_amd64.s",
"bluepill_unsafe.go",
"context.go",
"filters_amd64.go",
@@ -56,6 +56,7 @@ go_library(
"//pkg/sentry/time",
"//pkg/sync",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -78,6 +79,15 @@ go_test(
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/sentry/time",
"//pkg/usermem",
],
)
+
+genrule(
+ name = "bluepill_impl_amd64",
+ srcs = ["bluepill_amd64.s"],
+ outs = ["bluepill_impl_amd64.s"],
+ cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 2bc34a435..025ea93b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -19,11 +19,6 @@
// This is guaranteed to be zero.
#define VCPU_CPU 0x0
-// CPU_SELF is the self reference in ring0's percpu.
-//
-// This is guaranteed to be zero.
-#define CPU_SELF 0x0
-
// Context offsets.
//
// Only limited use of the context is done in the assembly stub below, most is
@@ -44,7 +39,7 @@ begin:
LEAQ VCPU_CPU(AX), BX
BYTE CLI;
check_vcpu:
- MOVQ CPU_SELF(GS), CX
+ MOVQ ENTRY_CPU_SELF(GS), CX
CMPQ BX, CX
JE right_vCPU
wrong_vcpu:
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 03a98512e..0a54dd30d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
- return c.runData.readyForInterruptInjection != 0
+ if c.runData.readyForInterruptInjection == 0 {
+ return false
+ }
+
+ if c.runData.ifFlag == 0 {
+ // This is impossible if readyForInterruptInjection is 1.
+ throw("interrupts are disabled")
+ }
+
+ // Disable interrupts if we are in the kernel space.
+ //
+ // When the Sentry switches into the kernel mode, it disables
+ // interrupts. But when goruntime switches on a goroutine which has
+ // been saved in the host mode, it restores flags and this enables
+ // interrupts. See the comment of UserFlagsSet for more details.
+ uregs := userRegs{}
+ err := c.getUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to get user registers")
+ }
+
+ if ring0.IsKernelFlags(uregs.RFLAGS) {
+ uregs.RFLAGS &^= ring0.KernelFlagsClear
+ err = c.setUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to set user registers")
+ }
+ return false
+ }
+ return true
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 979be5d89..eb05950cd 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -62,6 +62,9 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
//
//go:nosplit
func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Increment our counter.
+ atomic.AddUint64(&c.guestExits, 1)
+
// Copy out registers.
bluepillArchExit(c, bluepillArchContext(context))
@@ -89,9 +92,6 @@ func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
c := bluepillArchEnter(bluepillArchContext(context))
- // Increment the number of switches.
- atomic.AddUint32(&c.switches, 1)
-
// Mark this as guest mode.
switch atomic.SwapUint32(&c.state, vCPUGuest|vCPUUser) {
case vCPUUser: // Expected case.
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index 6e6b76416..17268d127 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -15,6 +15,8 @@
package kvm
import (
+ "sync/atomic"
+
pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -75,6 +77,9 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
// Clear the address space.
cpu.active.set(nil)
+ // Increment the number of user exits.
+ atomic.AddUint64(&cpu.userExits, 1)
+
// Release resources.
c.machine.Put(cpu)
diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go
index 7d949f1dd..d3d216aa5 100644
--- a/pkg/sentry/platform/kvm/filters_amd64.go
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
@@ -17,14 +17,23 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: {},
- syscall.SYS_IOCTL: {},
+ syscall.SYS_ARCH_PRCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
index 9245d07c2..21abc2a3d 100644
--- a/pkg/sentry/platform/kvm/filters_arm64.go
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -17,13 +17,22 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_IOCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ae813e24e..dd45ad10b 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -63,6 +63,9 @@ type runData struct {
type KVM struct {
platform.NoCPUPreemptionDetection
+ // KVM never changes mm_structs.
+ platform.UseHostProcessMemoryBarrier
+
// machine is the backing VM.
machine *machine
}
@@ -156,15 +159,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
pageTables := pagetables.New(newAllocator())
- applyPhysicalRegions(func(pr physicalRegion) bool {
- // Map the kernel in the upper half.
- pageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
- return true // Keep iterating.
- })
+ k.machine.mapUpperHalf(pageTables)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 5c4b18899..6abaa21c4 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -26,12 +26,16 @@ const (
_KVM_RUN = 0xae80
_KVM_NMI = 0xae9a
_KVM_CHECK_EXTENSION = 0xae03
+ _KVM_GET_TSC_KHZ = 0xaea3
+ _KVM_SET_TSC_KHZ = 0xaea2
_KVM_INTERRUPT = 0x4004ae86
_KVM_SET_MSRS = 0x4008ae89
_KVM_SET_USER_MEMORY_REGION = 0x4020ae46
_KVM_SET_REGS = 0x4090ae82
_KVM_SET_SREGS = 0x4138ae84
+ _KVM_GET_MSRS = 0xc008ae88
_KVM_GET_REGS = 0x8090ae81
+ _KVM_GET_SREGS = 0x8138ae83
_KVM_GET_SUPPORTED_CPUID = 0xc008ae05
_KVM_SET_CPUID2 = 0x4008ae90
_KVM_SET_SIGNAL_MASK = 0x4004ae8b
@@ -79,11 +83,14 @@ const (
)
// KVM hypercall list.
+//
// Canonical list of hypercalls supported.
const (
// On amd64, it uses 'HLT' to leave the guest.
+ //
// Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest.
- // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ //
+ // _KVM_HYPERCALL_VMEXIT is only used on arm64 for now.
_KVM_HYPERCALL_VMEXIT int = iota
_KVM_HYPERCALL_MAX
)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 9a7be3655..84df0f878 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -101,13 +101,20 @@ const (
// Arm64: Memory Attribute Indirection Register EL1.
const (
- _MT_DEVICE_nGnRnE = 0
- _MT_DEVICE_nGnRE = 1
- _MT_DEVICE_GRE = 2
- _MT_NORMAL_NC = 3
- _MT_NORMAL = 4
- _MT_NORMAL_WT = 5
- _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_ATTR_DEVICE_nGnRnE = 0x00
+ _MT_ATTR_DEVICE_nGnRE = 0x04
+ _MT_ATTR_DEVICE_GRE = 0x0c
+ _MT_ATTR_NORMAL_NC = 0x44
+ _MT_ATTR_NORMAL_WT = 0xbb
+ _MT_ATTR_NORMAL = 0xff
+ _MT_ATTR_MASK = 0xff
+ _MT_EL1_INIT = (_MT_ATTR_DEVICE_nGnRnE << (_MT_DEVICE_nGnRnE * 8)) | (_MT_ATTR_DEVICE_nGnRE << (_MT_DEVICE_nGnRE * 8)) | (_MT_ATTR_DEVICE_GRE << (_MT_DEVICE_GRE * 8)) | (_MT_ATTR_NORMAL_NC << (_MT_NORMAL_NC * 8)) | (_MT_ATTR_NORMAL << (_MT_NORMAL * 8)) | (_MT_ATTR_NORMAL_WT << (_MT_NORMAL_WT * 8))
)
const (
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 45b3180f1..e58acc071 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -411,9 +412,9 @@ func TestWrongVCPU(t *testing.T) {
// Basic test, one then the other.
bluepill(c1)
bluepill(c2)
- if c2.switches == 0 {
+ if c2.guestExits == 0 {
// Don't allow the test to proceed if this fails.
- t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ t.Fatalf("wrong vCPU#2 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
// Alternate vCPUs; we expect to need to trigger the
@@ -422,11 +423,11 @@ func TestWrongVCPU(t *testing.T) {
bluepill(c1)
bluepill(c2)
}
- if count := c1.switches; count < 90 {
- t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ if count := c1.guestExits; count < 90 {
+ t.Errorf("wrong vCPU#1 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
- if count := c2.switches; count < 90 {
- t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ if count := c2.guestExits; count < 90 {
+ t.Errorf("wrong vCPU#2 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
return false
})
@@ -442,6 +443,22 @@ func TestWrongVCPU(t *testing.T) {
})
}
+func TestRdtsc(t *testing.T) {
+ var i int // Iteration count.
+ kvmTest(t, nil, func(c *vCPU) bool {
+ start := ktime.Rdtsc()
+ bluepill(c)
+ guest := ktime.Rdtsc()
+ redpill()
+ end := ktime.Rdtsc()
+ if start > guest || guest > end {
+ t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end)
+ }
+ i++
+ return i < 100
+ })
+}
+
func BenchmarkApplicationSyscall(b *testing.B) {
var (
i int // Iteration includes machine.Get() / machine.Put().
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 372a4cbd7..61ed24d01 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -103,8 +103,11 @@ type vCPU struct {
// tid is the last set tid.
tid uint64
- // switches is a count of world switches (informational only).
- switches uint32
+ // userExits is the count of user exits.
+ userExits uint64
+
+ // guestExits is the count of guest to host world switches.
+ guestExits uint64
// faults is a count of world faults (informational only).
faults uint32
@@ -127,6 +130,7 @@ type vCPU struct {
// vCPUArchState is the architecture-specific state.
vCPUArchState
+ // dieState holds state related to vCPU death.
dieState dieState
}
@@ -155,7 +159,7 @@ func (m *machine) newVCPU() *vCPU {
fd: int(fd),
machine: m,
}
- c.CPU.Init(&m.kernel, c)
+ c.CPU.Init(&m.kernel, c.id, c)
m.vCPUsByID[c.id] = c
// Ensure the signal mask is correct.
@@ -183,9 +187,6 @@ func newMachine(vm int) (*machine, error) {
// Create the machine.
m := &machine{fd: vm}
m.available.L = &m.mu
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- })
// Pull the maximum vCPUs.
maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
@@ -197,6 +198,9 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ }, m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -219,15 +223,9 @@ func newMachine(vm int) (*machine, error) {
pagetables.MapOpts{AccessType: usermem.AnyAccess},
pr.physical)
- // And keep everything in the upper half.
- m.kernel.PageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
-
return true // Keep iterating.
})
+ m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -365,6 +363,11 @@ func (m *machine) Destroy() {
// Get gets an available vCPU.
//
// This will return with the OS thread locked.
+//
+// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points
+// to the vCPU in which the OS thread TID is running. So if Get() returns with
+// the corrent context in guest, the vCPU of it must be the same as what
+// Get() returns.
func (m *machine) Get() *vCPU {
m.mu.RLock()
runtime.LockOSThread()
@@ -469,6 +472,19 @@ func (m *machine) newDirtySet() *dirtySet {
}
}
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
+ }
+}
+
// lock marks the vCPU as in user mode.
//
// This should only be called directly when known to be safe, i.e. when
@@ -528,6 +544,8 @@ var pid = syscall.Getpid()
//
// This effectively unwinds the state machine.
func (c *vCPU) bounce(forceGuestExit bool) {
+ origGuestExits := atomic.LoadUint64(&c.guestExits)
+ origUserExits := atomic.LoadUint64(&c.userExits)
for {
switch state := atomic.LoadUint32(&c.state); state {
case vCPUReady, vCPUWaiter:
@@ -583,6 +601,14 @@ func (c *vCPU) bounce(forceGuestExit bool) {
// Should not happen: the above is exhaustive.
panic("invalid state")
}
+
+ // Check if we've missed the state transition, but
+ // we can safely return at this point in time.
+ newGuestExits := atomic.LoadUint64(&c.guestExits)
+ newUserExits := atomic.LoadUint64(&c.userExits)
+ if newUserExits != origUserExits && (!forceGuestExit || newGuestExits != origGuestExits) {
+ return
+ }
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index acc823ba6..c67127d95 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -18,14 +18,17 @@ package kvm
import (
"fmt"
+ "math/big"
"reflect"
"runtime/debug"
"syscall"
+ "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -84,19 +87,6 @@ const (
poolPCIDs = 8
)
-// dropPageTables drops cached page table entries.
-func (m *machine) dropPageTables(pt *pagetables.PageTables) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- // Clear from all PCIDs.
- for _, c := range m.vCPUsByID {
- if c != nil && c.PCIDs != nil {
- c.PCIDs.Drop(pt)
- }
- }
-}
-
// initArchState initializes architecture-specific state.
func (c *vCPU) initArchState() error {
var (
@@ -144,6 +134,7 @@ func (c *vCPU) initArchState() error {
// Set the entrypoint for the kernel.
kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RSP = c.StackTop()
kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
// Set the system registers.
@@ -152,8 +143,8 @@ func (c *vCPU) initArchState() error {
}
// Set the user registers.
- if err := c.setUserRegisters(&kernelUserRegs); err != nil {
- return err
+ if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
}
// Allocate some floating point state save area for the local vCPU.
@@ -166,6 +157,133 @@ func (c *vCPU) initArchState() error {
return c.setSystemTime()
}
+// bitsForScaling returns the bits available for storing the fraction component
+// of the TSC scaling ratio. This allows us to replicate the (bad) math done by
+// the kernel below in scaledTSC, and ensure we can compute an exact zero
+// offset in setSystemTime.
+//
+// These constants correspond to kvm_tsc_scaling_ratio_frac_bits.
+var bitsForScaling = func() int64 {
+ fs := cpuid.HostFeatureSet()
+ if fs.Intel() {
+ return 48 // See vmx.c (kvm sources).
+ } else if fs.AMD() {
+ return 32 // See svm.c (svm sources).
+ } else {
+ return 63 // Unknown: theoretical maximum.
+ }
+}()
+
+// scaledTSC returns the host TSC scaled by the given frequency.
+//
+// This assumes a current frequency of 1. We require only the unitless ratio of
+// rawFreq to some current frequency. See setSystemTime for context.
+//
+// The kernel math guarantees that all bits of the multiplication and division
+// will be correctly preserved and applied. However, it is not possible to
+// actually store the ratio correctly. So we need to use the same schema in
+// order to calculate the scaled frequency and get the same result.
+//
+// We can assume that the current frequency is (1), so we are calculating a
+// strict inverse of this value. This simplifies this function considerably.
+//
+// Roughly, the returned value "scaledTSC" will have:
+// scaledTSC/hostTSC == 1/rawFreq
+//
+//go:nosplit
+func scaledTSC(rawFreq uintptr) int64 {
+ scale := int64(1 << bitsForScaling)
+ ratio := big.NewInt(scale / int64(rawFreq))
+ ratio.Mul(ratio, big.NewInt(int64(ktime.Rdtsc())))
+ ratio.Div(ratio, big.NewInt(scale))
+ return ratio.Int64()
+}
+
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ // First, scale down the clock frequency to the lowest value allowed by
+ // the API itself. How low we can go depends on the underlying
+ // hardware, but it is typically ~1/2^48 for Intel, ~1/2^32 for AMD.
+ // Even the lower bound here will take a 4GHz frequency down to 1Hz,
+ // meaning that everything should be able to handle a Khz setting of 1
+ // with bits to spare.
+ //
+ // Note that reducing the clock does not typically require special
+ // capabilities as it is emulated in KVM. We don't actually use this
+ // capability, but it means that this method should be robust to
+ // different hardware configurations.
+ rawFreq, err := c.getTSCFreq()
+ if err != nil {
+ return c.setSystemTimeLegacy()
+ }
+ if err := c.setTSCFreq(1); err != nil {
+ return c.setSystemTimeLegacy()
+ }
+
+ // Always restore the original frequency.
+ defer func() {
+ if err := c.setTSCFreq(rawFreq); err != nil {
+ panic(err.Error())
+ }
+ }()
+
+ // Attempt to set the system time in this compressed world. The
+ // calculation for offset normally looks like:
+ //
+ // offset = target_tsc - kvm_scale_tsc(vcpu, rdtsc());
+ //
+ // So as long as the kvm_scale_tsc component is constant before and
+ // after the call to set the TSC value (and it is passes as the
+ // target_tsc), we will compute an offset value of zero.
+ //
+ // This is effectively cheating to make our "setSystemTime" call so
+ // unbelievably, incredibly fast that we do it "instantly" and all the
+ // calculations result in an offset of zero.
+ lastTSC := scaledTSC(rawFreq)
+ for {
+ if err := c.setTSC(uint64(lastTSC)); err != nil {
+ return err
+ }
+ nextTSC := scaledTSC(rawFreq)
+ if lastTSC == nextTSC {
+ return nil
+ }
+ lastTSC = nextTSC // Try again.
+ }
+}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
+
// nonCanonical generates a canonical address return.
//
//go:nosplit
@@ -345,3 +463,41 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
+
+var execRegions = func() (regions []region) {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
+ return
+ }
+ if vr.accessType.Execute {
+ regions = append(regions, vr.region)
+ }
+ })
+ return
+}()
+
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ for _, r := range execRegions {
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossilbe translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
+ }
+ for start, end := range m.kernel.EntryRegions() {
+ regionLen := end - start
+ physical, length, ok := translateToPhysical(start)
+ if !ok || length < regionLen {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|start),
+ regionLen,
+ pagetables.MapOpts{AccessType: usermem.ReadWrite},
+ physical)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 290f035dd..b430f92c6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -23,7 +23,6 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/time"
)
// loadSegments copies the current segments.
@@ -61,91 +60,63 @@ func (c *vCPU) setCPUID() error {
return nil
}
-// setSystemTime sets the TSC for the vCPU.
+// getTSCFreq gets the TSC frequency.
//
-// This has to make the call many times in order to minimize the intrinsic
-// error in the offset. Unfortunately KVM does not expose a relative offset via
-// the API, so this is an approximation. We do this via an iterative algorithm.
-// This has the advantage that it can generally deal with highly variable
-// system call times and should converge on the correct offset.
-func (c *vCPU) setSystemTime() error {
- const (
- _MSR_IA32_TSC = 0x00000010
- calibrateTries = 10
- )
- registers := modelControlRegisters{
- nmsrs: 1,
- }
- registers.entries[0] = modelControlRegister{
- index: _MSR_IA32_TSC,
+// If mustSucceed is true, then this function panics on error.
+func (c *vCPU) getTSCFreq() (uintptr, error) {
+ rawFreq, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_TSC_KHZ,
+ 0 /* ignored */)
+ if errno != 0 {
+ return 0, errno
}
- target := uint64(^uint32(0))
- for done := 0; done < calibrateTries; {
- start := uint64(time.Rdtsc())
- registers.entries[0].data = start + target
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_MSRS,
- uintptr(unsafe.Pointer(&registers))); errno != 0 {
- return fmt.Errorf("error setting system time: %v", errno)
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result. So we only count attempts
- // within +/- 6.25% of our minimum as an attempt.
- end := uint64(time.Rdtsc())
- if end < start {
- continue // Totally bogus.
- }
- half := (end - start) / 2
- if half < target {
- target = half
- }
- if (half - target) < target/8 {
- done++
- }
+ return rawFreq, nil
+}
+
+// setTSCFreq sets the TSC frequency.
+func (c *vCPU) setTSCFreq(freq uintptr) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_TSC_KHZ,
+ freq /* khz */); errno != 0 {
+ return fmt.Errorf("error setting TSC frequency: %v", errno)
}
return nil
}
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
+// setTSC sets the TSC value.
+func (c *vCPU) setTSC(value uint64) error {
+ const _MSR_IA32_TSC = 0x00000010
+ registers := modelControlRegisters{
+ nmsrs: 1,
}
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
+ registers.entries[0].index = _MSR_IA32_TSC
+ registers.entries[0].data = value
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
+ _KVM_SET_MSRS,
+ uintptr(unsafe.Pointer(&registers))); errno != 0 {
+ return fmt.Errorf("error setting tsc: %v", errno)
}
return nil
}
// setUserRegisters sets user registers in the vCPU.
-func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+//
+//go:nosplit
+func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno {
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_REGS,
uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return fmt.Errorf("error setting user registers: %v", errno)
+ return errno
}
- return nil
+ return 0
}
// getUserRegisters reloads user registers in the vCPU.
@@ -175,3 +146,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
}
return nil
}
+
+// getSystemRegisters sets system registers.
+//
+//go:nosplit
+func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 9db171af9..54837f20c 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -19,6 +19,7 @@ package kvm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -48,6 +49,18 @@ const (
poolPCIDs = 8
)
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+}
+
// Get all read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
var rdonlyRegions []region
@@ -100,19 +113,6 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return phyRegions
}
-// dropPageTables drops cached page table entries.
-func (m *machine) dropPageTables(pt *pagetables.PageTables) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- // Clear from all PCIDs.
- for _, c := range m.vCPUsByID {
- if c.PCIDs != nil {
- c.PCIDs.Drop(pt)
- }
- }
-}
-
// nonCanonical generates a canonical address return.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 537419657..a163f956d 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -191,42 +191,6 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
return nil
}
-// setCPUID sets the CPUID to be used by the guest.
-func (c *vCPU) setCPUID() error {
- return nil
-}
-
-// setSystemTime sets the TSC for the vCPU.
-func (c *vCPU) setSystemTime() error {
- return nil
-}
-
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
- }
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
- }
-
- return nil
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
// Check for canonical addresses.
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 607c82156..1d6ca245a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -143,3 +143,29 @@ func (c *vCPU) waitUntilNot(state uint32) {
panic("futex wait error")
}
}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 530e779b0..dcfe839a7 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/hostmm"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,6 +53,10 @@ type Platform interface {
// can reliably return ErrContextCPUPreempted.
DetectsCPUPreemption() bool
+ // HaveGlobalMemoryBarrier returns true if the GlobalMemoryBarrier method
+ // is supported.
+ HaveGlobalMemoryBarrier() bool
+
// MapUnit returns the alignment used for optional mappings into this
// platform's AddressSpaces. Higher values indicate lower per-page costs
// for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
@@ -97,6 +102,15 @@ type Platform interface {
// called.
PreemptAllCPUs() error
+ // GlobalMemoryBarrier blocks until all threads running application code
+ // (via Context.Switch) and all task goroutines "have passed through a
+ // state where all memory accesses to user-space addresses match program
+ // order between entry to and return from [GlobalMemoryBarrier]", as for
+ // membarrier(2).
+ //
+ // Preconditions: HaveGlobalMemoryBarrier() == true.
+ GlobalMemoryBarrier() error
+
// SyscallFilters returns syscalls made exclusively by this platform.
SyscallFilters() seccomp.SyscallRules
}
@@ -115,6 +129,43 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
panic("This platform does not support CPU preemption detection")
}
+// UseHostGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking equivalent functionality on the
+// host.
+type UseHostGlobalMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ return hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) GlobalMemoryBarrier() error {
+ return hostmm.GlobalMemoryBarrier()
+}
+
+// UseHostProcessMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking a process-local memory barrier.
+// This is faster than UseHostGlobalMemoryBarrier, but is only appropriate for
+// platforms for which application code executes while using the sentry's
+// mm_struct.
+type UseHostProcessMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ // Fall back to a global memory barrier if a process-local one isn't
+ // available.
+ return hostmm.HaveProcessMemoryBarrier() || hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) GlobalMemoryBarrier() error {
+ if hostmm.HaveProcessMemoryBarrier() {
+ return hostmm.ProcessMemoryBarrier()
+ }
+ return hostmm.GlobalMemoryBarrier()
+}
+
// MemoryManager represents an abstraction above the platform address space
// which manages memory mappings and their contents.
type MemoryManager interface {
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index b52d0fbd8..f56aa3b79 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -192,6 +192,7 @@ func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {}
type PTrace struct {
platform.MMapMinAddr
platform.NoCPUPreemptionDetection
+ platform.UseHostGlobalMemoryBarrier
}
// New returns a new ptrace-based implementation of the platform interface.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 9c6c2cf5c..00899273e 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -76,15 +76,41 @@ type KernelOpts struct {
type KernelArchState struct {
KernelOpts
+ // cpuEntries is array of kernelEntry for all cpus
+ cpuEntries []kernelEntry
+
// globalIDT is our set of interrupt gates.
- globalIDT idt64
+ globalIDT *idt64
}
-// CPUArchState contains CPU-specific arch state.
-type CPUArchState struct {
+// kernelEntry contains minimal CPU-specific arch state
+// that can be mapped at the upper of the address space.
+// Malicious APP might steal info from it via CPU bugs.
+type kernelEntry struct {
// stack is the stack used for interrupts on this CPU.
stack [256]byte
+ // scratch space for temporary usage.
+ scratch0 uint64
+
+ // stackTop is the top of the stack.
+ stackTop uint64
+
+ // cpuSelf is back reference to CPU.
+ cpuSelf *CPU
+
+ // kernelCR3 is the cr3 used for sentry kernel.
+ kernelCR3 uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
// errorCode is the error code from the last exception.
errorCode uintptr
@@ -97,11 +123,7 @@ type CPUArchState struct {
// exception.
errorType uintptr
- // gdt is the CPU's descriptor table.
- gdt descriptorTable
-
- // tss is the CPU's task state.
- tss TaskState64
+ *kernelEntry
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index 7fa43c2f5..d87b1fd00 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -36,12 +36,15 @@ func sysenter()
// This must be called prior to sysret/iret.
func swapgs()
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
// sysret returns to userspace from a system call.
//
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *arch.Registers) Vector
+func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *arch.Registers) Vector
+func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s
index 02df38331..f59747df3 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.s
+++ b/pkg/sentry/platform/ring0/entry_amd64.s
@@ -63,6 +63,15 @@
MOVQ offset+PTRACE_RSI(reg), SI; \
MOVQ offset+PTRACE_RDI(reg), DI;
+// WRITE_CR3() writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+#define WRITE_CR3() \
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+
// SWAP_GS swaps the kernel GS (CPU).
#define SWAP_GS() \
BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
@@ -75,15 +84,9 @@
#define SYSRET64() \
BYTE $0x48; BYTE $0x0f; BYTE $0x07;
-// LOAD_KERNEL_ADDRESS loads a kernel address.
-#define LOAD_KERNEL_ADDRESS(from, to) \
- MOVQ from, to; \
- ORQ ·KernelStartAddress(SB), to;
-
// LOAD_KERNEL_STACK loads the kernel stack.
-#define LOAD_KERNEL_STACK(from) \
- LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
- LEAQ CPU_STACK_TOP(SP), SP;
+#define LOAD_KERNEL_STACK(entry) \
+ MOVQ ENTRY_STACK_TOP(entry), SP;
// See kernel.go.
TEXT ·Halt(SB),NOSPLIT,$0
@@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0
SWAP_GS()
RET
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return value to the kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ ORQ ·KernelStartAddress(SB), AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
// See entry_amd64.go.
TEXT ·sysret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+ // save SP AX userCR3 on the kernel stack.
+ MOVQ CPU_ENTRY(BX), BX
+ LOAD_KERNEL_STACK(BX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_RAX(AX)
+ PUSHQ CX
+
// Restore user register state.
REGISTERS_LOAD(AX, 0)
MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
- MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+
+ // restore userCR3, AX, SP.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
+ POPQ SP // Restore SP.
SYSRET64()
// See entry_amd64.go.
TEXT ·iret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
// Build an IRET frame & restore state.
+ MOVQ CPU_ENTRY(BX), BX
LOAD_KERNEL_STACK(BX)
- MOVQ PTRACE_SS(AX), BX; PUSHQ BX
- MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
- MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
- MOVQ PTRACE_CS(AX), DI; PUSHQ DI
- MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
- REGISTERS_LOAD(AX, 0) // Restore most registers.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ PUSHQ PTRACE_SS(AX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_FLAGS(AX)
+ PUSHQ PTRACE_CS(AX)
+ PUSHQ PTRACE_RIP(AX)
+ PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack.
+ PUSHQ CX // Save userCR3 on kernel stack.
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
IRET()
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
// See iret, above.
- MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
- MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
- MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
- MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
- REGISTERS_LOAD(GS, CPU_REGISTERS)
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ CPU_REGISTERS+PTRACE_SS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RSP(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_CS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RIP(AX)
+ REGISTERS_LOAD(AX, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX
IRET()
// See entry_amd64.go.
TEXT ·Start(SB),NOSPLIT,$0
- LOAD_KERNEL_STACK(AX) // Set the stack.
PUSHQ $0x0 // Previous frame pointer.
MOVQ SP, BP // Set frame pointer.
PUSHQ AX // First argument (CPU).
@@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0
// See entry_amd64.go.
TEXT ·sysenter(SB),NOSPLIT,$0
- // Interrupts are always disabled while we're executing in kernel mode
- // and always enabled while executing in user mode. Therefore, we can
- // reliably look at the flags in R11 to determine where this syscall
- // was from.
- TESTL $_RFLAGS_IF, R11
+ // _RFLAGS_IOPL0 is always set in the user mode and it is never set in
+ // the kernel mode. See the comment of UserFlagsSet for more details.
+ TESTL $_RFLAGS_IOPL0, R11
JZ kernel
-
user:
SWAP_GS()
- XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
- MOVQ BX, PTRACE_RAX(AX) // Save everything else.
- MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ CX, PTRACE_RIP(AX)
MOVQ R11, PTRACE_FLAGS(AX)
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ SP, PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value.
+ MOVQ CX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ CX, PTRACE_ORIGRAX(AX)
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks.
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
// Return to the kernel, where the frame is:
//
- // vector (sp+24)
+ // vector (sp+32)
+ // userCR3 (sp+24)
// regs (sp+16)
// cpu (sp+8)
// vcpu.Switch (sp+0)
//
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ $Syscall, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ $Syscall, 32(SP) // Output vector.
RET
kernel:
// We can't restore the original stack, but we can access the registers
// in the CPU state directly. No need for temporary juggling.
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ AX, ENTRY_SCRATCH0(GS)
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ AX // First argument (vCPU).
CALL ·kernelSyscall(SB) // Call the trampoline.
POPQ AX // Pop vCPU.
@@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0
// ERROR_CODE (sp+8)
// VECTOR (sp+0)
//
- TESTL $_RFLAGS_IF, 32(SP)
+ TESTL $_RFLAGS_IOPL0, 32(SP)
JZ kernel
user:
SWAP_GS()
ADDQ $-8, SP // Adjust for flags.
MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ PUSHQ AX // Save user AX on stack.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ POPQ BX // Restore original AX.
MOVQ BX, PTRACE_RAX(AX) // Save it.
MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
@@ -249,34 +299,36 @@ user:
MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
// Copy out and return.
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
MOVQ 0(SP), BX // Load vector.
MOVQ 8(SP), CX // Load error code.
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
- MOVQ BX, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(AX) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
+ MOVQ BX, 32(SP) // Output vector.
RET
kernel:
// As per above, we can save directly.
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+ PUSHQ AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ POPQ BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX)
// Set the error code and adjust the stack.
- MOVQ 8(SP), AX // Load the error code.
- MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 8(SP), BX // Load the error code.
+ MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
MOVQ 0(SP), BX // BX contains the vector.
- ADDQ $48, SP // Drop the exception frame.
// Call the exception trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ BX // Second argument (vector).
PUSHQ AX // First argument (vCPU).
CALL ·kernelException(SB) // Call the trampoline.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 030e4bb9f..274576e2d 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -47,8 +47,9 @@
#define SCTLR_C 1 << 2
#define SCTLR_I 1 << 12
#define SCTLR_UCT 1 << 15
+#define SCTLR_UCI 1 << 26
-#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT | SCTLR_UCI)
// cntkctl_el1: counter-timer kernel control register el1.
#define CNTKCTL_EL0PCTEN 1 << 0
@@ -340,6 +341,8 @@
ADD $16, RSP, RSP; \
MOVD RSV_REG, PTRACE_R18(R20); \
MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MRS TPIDR_EL0, R3; \
+ MOVD R3, PTRACE_TLS(R20); \
WORD $0xd5384003; \ // MRS SPSR_EL1, R3
MOVD R3, PTRACE_PSTATE(R20); \
MRS ELR_EL1, R3; \
@@ -352,6 +355,8 @@
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ MRS TPIDR_EL0, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_TLS(RSV_REG); \
WORD $0xd5384004; \ // MRS SPSR_EL1, R4
MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
MRS ELR_EL1, R4; \
@@ -428,6 +433,8 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MRS TPIDR_EL1, RSV_REG
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+ MRS TPIDR_EL0, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_TLS(RSV_REG)
WORD $0xd5384003 // MRS SPSR_EL1, R3
MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
@@ -454,8 +461,18 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MOVD PTRACE_PSTATE(RSV_REG_APP), R1
WORD $0xd5184001 //MSR R1, SPSR_EL1
+ // need use kernel space address to excute below code, since
+ // after SWITCH_TO_APP_PAGETABLE the ASID is changed to app's
+ // ASID.
+ WORD $0x10000061 // ADR R1, do_exit_to_el0
+ ORR $0xffff000000000000, R1, R1
+ JMP (R1)
+
+do_exit_to_el0:
// RSV_REG & RSV_REG_APP will be loaded at the end.
REGISTERS_LOAD(RSV_REG_APP, 0)
+ MOVD PTRACE_TLS(RSV_REG_APP), RSV_REG
+ MSR RSV_REG, TPIDR_EL0
// switch to user pagetable.
MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 549f3d228..9742308d8 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,7 +24,10 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
- visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
+ visibility = [
+ "//pkg/sentry/platform/kvm:__pkg__",
+ "//pkg/sentry/platform/ring0:__pkg__",
+ ],
deps = [
"//pkg/cpuid",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 021693791..264be23d3 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -19,8 +19,8 @@ package ring0
// N.B. that constraints on KernelOpts must be satisfied.
//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts) {
- k.init(opts)
+func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
+ k.init(opts, maxCPUs)
}
// Halt halts execution.
@@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) {
// kernelSyscall is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) {
// kernelException is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) {
// Init initializes a new CPU.
//
// Init allows embedding in other objects.
-func (c *CPU) Init(k *Kernel, hooks Hooks) {
- c.self = c // Set self reference.
- c.kernel = k // Set kernel reference.
- c.init() // Perform architectural init.
+func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init(cpuID) // Perform architectural init.
// Require hooks.
if hooks != nil {
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index d37981dbf..3a9dff4cc 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -18,13 +18,42 @@ package ring0
import (
"encoding/binary"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/usermem"
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
+ entrySize := reflect.TypeOf(kernelEntry{}).Size()
+ var (
+ entries []kernelEntry
+ padding = 1
+ )
+ for {
+ entries = make([]kernelEntry, maxCPUs+padding-1)
+ totalSize := entrySize * uintptr(maxCPUs+padding-1)
+ addr := reflect.ValueOf(&entries[0]).Pointer()
+ if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize {
+ // The runtime forces power-of-2 alignment for allocations, and we are therefore
+ // safe once the first address is aligned and the chunk is at least a full page.
+ break
+ }
+ padding = padding << 1
+ }
+ k.cpuEntries = entries
+
+ k.globalIDT = &idt64{}
+ if reflect.TypeOf(idt64{}).Size() != usermem.PageSize {
+ panic("Size of globalIDT should be PageSize")
+ }
+ if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 {
+ panic("Allocated globalIDT should be page aligned")
+ }
+
// Setup the IDT, which is uniform.
for v, handler := range handlers {
// Allow Breakpoint and Overflow to be called from all
@@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) {
}
}
+func (k *Kernel) EntryRegions() map[uintptr]uintptr {
+ regions := make(map[uintptr]uintptr)
+
+ addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer()
+ size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries))
+ end, _ := usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ addr = reflect.ValueOf(k.globalIDT).Pointer()
+ size = reflect.TypeOf(idt64{}).Size()
+ end, _ = usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ return regions
+}
+
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
+ c.kernelEntry = &c.kernel.cpuEntries[cpuID]
+ c.cpuSelf = c
// Null segment.
c.gdt[0].setNull()
@@ -65,6 +112,7 @@ func (c *CPU) init() {
// Set the kernel stack pointer in the TSS (virtual address).
stackAddr := c.StackTop()
+ c.stackTop = stackAddr
c.tss.rsp0Lo = uint32(stackAddr)
c.tss.rsp0Hi = uint32(stackAddr >> 32)
c.tss.ist1Lo = uint32(stackAddr)
@@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
- kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+ c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID))
// Sanitize registers.
regs := switchOpts.Registers
@@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
- jumpToKernel() // Switch to upper half.
- writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
- vector = iret(c, regs)
+ vector = iret(c, regs, uintptr(userCR3))
} else {
- vector = sysret(c, regs)
+ vector = sysret(c, regs, uintptr(userCR3))
}
- writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
- jumpToUser() // Return to lower half.
SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
@@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
//go:nosplit
func start(c *CPU) {
// Save per-cpu & FS segment.
- WriteGS(kernelAddr(c))
+ WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
// Initialize floating point.
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index 14774c5db..b294ccc7c 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,13 +25,13 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
}
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
// Set the kernel stack pointer(virtual address).
c.registers.Sp = uint64(c.StackTop())
@@ -64,11 +64,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate |= UserFlagsSet
LoadFloatingPoint(switchOpts.FloatingPointState)
- SetTLS(regs.TPIDR_EL0)
kernelExitToEl0()
- regs.TPIDR_EL0 = GetTLS()
SaveFloatingPoint(switchOpts.FloatingPointState)
vector = c.vecCode
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
index ca968a036..0ec5c3bc5 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.go
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -61,21 +61,9 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
-// writeCR3 writes the CR3 value.
-func writeCR3(phys uintptr)
-
-// readCR3 reads the current CR3 value.
-func readCR3() uintptr
-
// readCR2 reads the current CR2 value.
func readCR2() uintptr
-// jumpToKernel jumps to the kernel version of the current RIP.
-func jumpToKernel()
-
-// jumpToUser jumps to the user version of the current RIP.
-func jumpToUser()
-
// fninit initializes the floating point unit.
func fninit()
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
index 75d742750..2fe83568a 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.s
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30; // WRMSR
RET
-// jumpToUser changes execution to the user address.
-//
-// This works by changing the return value to the user version.
-TEXT ·jumpToUser(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- NOTQ BX
- ANDQ BX, SP // Switch the stack.
- ANDQ BX, BP // Switch the frame pointer.
- ANDQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// jumpToKernel changes execution to the kernel address space.
-//
-// This works by changing the return value to the kernel version.
-TEXT ·jumpToKernel(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- ORQ BX, SP // Switch the stack.
- ORQ BX, BP // Switch the frame pointer.
- ORQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// writeCR3 writes the given CR3 value.
-//
-// The code corresponds to:
-//
-// mov %rax, %cr3
-//
-TEXT ·writeCR3(SB),NOSPLIT,$0-8
- MOVQ cr3+0(FP), AX
- BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
- RET
-
-// readCR3 reads the current CR3 value.
-//
-// The code corresponds to:
-//
-// mov %cr3, %rax
-//
-TEXT ·readCR3(SB),NOSPLIT,$0-8
- BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
- MOVQ AX, ret+0(FP)
- RET
-
// readCR2 reads the current CR2 value.
//
// The code corresponds to:
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index b8ab120a0..ca4075b09 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -30,14 +30,21 @@ func Emit(w io.Writer) {
c := &CPU{}
fmt.Fprintf(w, "\n// CPU offsets.\n")
- fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
- fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+
+ e := &kernelEntry{}
+ fmt.Fprintf(w, "\n// CPU entry offsets.\n")
+ fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
+ fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0)
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 1d86b4bcf..45eba960d 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -125,4 +125,5 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_TLS 0x%02x\n", reflect.ValueOf(&p.TPIDR_EL0).Pointer()-reflect.ValueOf(p).Pointer())
}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 6409d1d91..520161755 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -78,7 +78,7 @@ const (
const (
executeDisable = xn
- optionMask = 0xfff | 0xfff<<48
+ optionMask = 0xfff | 0xffff<<48
protDefault = accessed | shared
)
@@ -188,7 +188,7 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) {
v |= mtNormal
} else {
v = v &^ user
- v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress.
+ v |= mtNormal
}
atomic.StoreUintptr((*uintptr)(p), v)
}
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 9da0ea685..34fbc1c35 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -39,7 +39,9 @@ const (
_RFLAGS_AC = 1 << 18
_RFLAGS_NT = 1 << 14
- _RFLAGS_IOPL = 3 << 12
+ _RFLAGS_IOPL0 = 1 << 12
+ _RFLAGS_IOPL1 = 1 << 13
+ _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1
_RFLAGS_DF = 1 << 10
_RFLAGS_IF = 1 << 9
_RFLAGS_STEP = 1 << 8
@@ -67,15 +69,45 @@ const (
KernelFlagsSet = _RFLAGS_RESERVED
// UserFlagsSet are always set in userspace.
- UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF
+ //
+ // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege
+ // level. The Current Privilege Level (CPL) of the task must be less
+ // than or equal to the IOPL in order for the task or program to access
+ // I/O ports.
+ //
+ // Here, _RFLAGS_IOPL0 is used only to determine whether the task is
+ // running in the kernel or userspace mode. In the user mode, the CPL is
+ // always 3 and it doesn't matter what IOPL is set if it is bellow CPL.
+ //
+ // We need to have one bit which will be always different in user and
+ // kernel modes. And we have to remember that even though we have
+ // KernelFlagsClear, we still can see some of these flags in the kernel
+ // mode. This can happen when the goruntime switches on a goroutine
+ // which has been saved in the host mode. On restore, the popf
+ // instruction is used to restore flags and this means that all flags
+ // what the goroutine has in the host mode will be restored in the
+ // kernel mode.
+ //
+ // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set
+ // it in the user mode. So if this flag is set, the task is running in
+ // the user mode and if it isn't set, the task is running in the kernel
+ // mode.
+ UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0
// KernelFlagsClear should always be clear in the kernel.
KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT
// UserFlagsClear are always cleared in userspace.
- UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL
+ UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1
)
+// IsKernelFlags returns true if rflags coresponds to the kernel mode.
+//
+// go:nosplit
+func IsKernelFlags(rflags uint64) bool {
+ return rflags&_RFLAGS_IOPL0 == 0
+}
+
// Vector is an exception vector.
type Vector uintptr
@@ -104,7 +136,7 @@ const (
VirtualizationException
SecurityException = 0x1e
SyscallInt80 = 0x80
- _NR_INTERRUPTS = SyscallInt80 + 1
+ _NR_INTERRUPTS = 0x100
)
// System call vectors.
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 87b077e68..163af329b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -78,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *socketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 0336a32d8..549787955 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,7 +39,7 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an stack.Matcher to an ABI struct.
+ // marshal converts from a stack.Matcher to an ABI struct.
marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
@@ -93,3 +95,71 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
}
return matchMaker.unmarshal(buf, filter)
}
+
+// targetMaker knows how to (un)marshal a target. Once registered,
+// marshalTarget and unmarshalTarget can be used.
+type targetMaker interface {
+ // id uniquely identifies the target.
+ id() stack.TargetID
+
+ // marshal converts from a stack.Target to an ABI struct.
+ marshal(target stack.Target) []byte
+
+ // unmarshal converts from the ABI matcher struct to a stack.Target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+}
+
+// targetMakers maps the TargetID of supported targets to the targetMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var targetMakers = map[stack.TargetID]targetMaker{}
+
+func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
+ tid := stack.TargetID{
+ Name: name,
+ NetworkProtocol: netProto,
+ Revision: rev,
+ }
+ if _, ok := targetMakers[tid]; !ok {
+ return 0, false
+ }
+
+ // Return the highest supported revision unless rev is higher.
+ for _, other := range targetMakers {
+ otherID := other.id()
+ if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
+ rev = uint8(otherID.Revision)
+ }
+ }
+ return rev, true
+}
+
+// registerTargetMaker should be called by target extensions to register them
+// with the netfilter package.
+func registerTargetMaker(tm targetMaker) {
+ if _, ok := targetMakers[tm.id()]; ok {
+ panic(fmt.Sprintf("multiple targets registered with name %q.", tm.id()))
+ }
+ targetMakers[tm.id()] = tm
+}
+
+func marshalTarget(target stack.Target) []byte {
+ targetMaker, ok := targetMakers[target.ID()]
+ if !ok {
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ }
+ return targetMaker.marshal(target)
+}
+
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
+ tid := stack.TargetID{
+ Name: target.Name.String(),
+ NetworkProtocol: filter.NetworkProtocol(),
+ Revision: target.Revision,
+ }
+ targetMaker, ok := targetMakers[tid]
+ if !ok {
+ nflog("unsupported target with name %q", target.Name.String())
+ return nil, syserr.ErrInvalidArgument
+ }
+ return targetMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index e4c55a100..b560fae0d 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -181,18 +181,23 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- target, err := parseTarget(filter, optVal[:targetSize])
- if err != nil {
- nflog("failed to parse target: %v", err)
- return nil, syserr.ErrInvalidArgument
- }
- optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, stack.Rule{
+ rule := stack.Rule{
Filter: filter,
- Target: target,
Matchers: matchers,
- })
+ }
+
+ {
+ target, err := parseTarget(filter, optVal[:targetSize], false /* ipv6 */)
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, err
+ }
+ rule.Target = target
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, rule)
offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 3b2c1becd..4253f7bf4 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -184,18 +184,23 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- target, err := parseTarget(filter, optVal[:targetSize])
- if err != nil {
- nflog("failed to parse target: %v", err)
- return nil, syserr.ErrInvalidArgument
- }
- optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, stack.Rule{
+ rule := stack.Rule{
Filter: filter,
- Target: target,
Matchers: matchers,
- })
+ }
+
+ {
+ target, err := parseTarget(filter, optVal[:targetSize], true /* ipv6 */)
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, err
+ }
+ rule.Target = target
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, rule)
offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 871ea80ee..904a12e38 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -146,10 +147,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
case stack.FilterTable:
table = stack.EmptyFilterTable()
case stack.NATTable:
- if ipv6 {
- nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)")
- return syserr.ErrInvalidArgument
- }
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
@@ -199,7 +196,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// Check the user chains.
for ruleIdx, rule := range table.Rules {
- if _, ok := rule.Target.(stack.UserChainTarget); !ok {
+ if _, ok := rule.Target.(*stack.UserChainTarget); !ok {
continue
}
@@ -220,7 +217,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// Set each jump to point to the appropriate rule. Right now they hold byte
// offsets.
for ruleIdx, rule := range table.Rules {
- jump, ok := rule.Target.(JumpTarget)
+ jump, ok := rule.Target.(*JumpTarget)
if !ok {
continue
}
@@ -311,7 +308,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case stack.AcceptTarget, stack.DropTarget:
+ case *stack.AcceptTarget, *stack.DropTarget:
return true
default:
return false
@@ -322,7 +319,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(stack.AcceptTarget)
+ _, ok := rule.Target.(*stack.AcceptTarget)
return ok
}
@@ -341,3 +338,20 @@ func hookFromLinux(hook int) stack.Hook {
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
+
+// TargetRevision returns a linux.XTGetRevision for a given target. It sets
+// Revision to the highest supported value, unless the provided revision number
+// is larger.
+func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) {
+ // Read in the target name and version.
+ var rev linux.XTGetRevision
+ if _, err := rev.CopyIn(t, revPtr); err != nil {
+ return linux.XTGetRevision{}, syserr.FromError(err)
+ }
+ maxSupported, ok := targetRevision(rev.Name.String(), netProto, rev.Revision)
+ if !ok {
+ return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported
+ }
+ rev.Revision = maxSupported
+ return rev, nil
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 87e41abd8..0e14447fe 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,255 +15,357 @@
package netfilter
import (
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
-// errorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const errorTargetName = "ERROR"
+func init() {
+ // Standard targets include ACCEPT, DROP, RETURN, and JUMP.
+ registerTargetMaker(&standardTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&standardTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+
+ // Both user chains and actual errors are represented in iptables by
+ // error targets.
+ registerTargetMaker(&errorTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&errorTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+
+ registerTargetMaker(&redirectTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&nfNATTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+}
-// redirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const redirectTargetName = "REDIRECT"
+type standardTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
-func marshalTarget(target stack.Target) []byte {
+func (sm *standardTargetMaker) id() stack.TargetID {
+ // Standard targets have the empty string as a name and no revisions.
+ return stack.TargetID{
+ NetworkProtocol: sm.NetworkProtocol,
+ }
+}
+func (*standardTargetMaker) marshal(target stack.Target) []byte {
+ // Translate verdicts the same way as the iptables tool.
+ var verdict int32
switch tg := target.(type) {
- case stack.AcceptTarget:
- return marshalStandardTarget(stack.RuleAccept)
- case stack.DropTarget:
- return marshalStandardTarget(stack.RuleDrop)
- case stack.ErrorTarget:
- return marshalErrorTarget(errorTargetName)
- case stack.UserChainTarget:
- return marshalErrorTarget(tg.Name)
- case stack.ReturnTarget:
- return marshalStandardTarget(stack.RuleReturn)
- case stack.RedirectTarget:
- return marshalRedirectTarget(tg)
- case JumpTarget:
- return marshalJumpTarget(tg)
+ case *stack.AcceptTarget:
+ verdict = -linux.NF_ACCEPT - 1
+ case *stack.DropTarget:
+ verdict = -linux.NF_DROP - 1
+ case *stack.ReturnTarget:
+ verdict = linux.NF_RETURN
+ case *JumpTarget:
+ verdict = int32(tg.Offset)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
-}
-
-func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
- nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
- target := linux.XTStandardTarget{
+ xt := linux.XTStandardTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTStandardTarget,
},
- Verdict: translateFromStandardVerdict(verdict),
+ Verdict: verdict,
}
ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
+}
+
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) != linux.SizeOfXTStandardTarget {
+ nflog("buf has wrong size for standard target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = buf[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict, filter.NetworkProtocol())
+ }
+ // A verdict >= 0 indicates a jump.
+ return &JumpTarget{
+ Offset: uint32(standardTarget.Verdict),
+ NetworkProtocol: filter.NetworkProtocol(),
+ }, nil
+}
+
+type errorTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (em *errorTargetMaker) id() stack.TargetID {
+ // Error targets have no revision.
+ return stack.TargetID{
+ Name: stack.ErrorTargetName,
+ NetworkProtocol: em.NetworkProtocol,
+ }
}
-func marshalErrorTarget(errorName string) []byte {
+func (*errorTargetMaker) marshal(target stack.Target) []byte {
+ var errorName string
+ switch tg := target.(type) {
+ case *stack.ErrorTarget:
+ errorName = stack.ErrorTargetName
+ case *stack.UserChainTarget:
+ errorName = tg.Name
+ default:
+ panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
+ }
+
// This is an error target named error
- target := linux.XTErrorTarget{
+ xt := linux.XTErrorTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTErrorTarget,
},
}
- copy(target.Name[:], errorName)
- copy(target.Target.Name[:], errorTargetName)
+ copy(xt.Name[:], errorName)
+ copy(xt.Target.Name[:], stack.ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) != linux.SizeOfXTErrorTarget {
+ nflog("buf has insufficient size for error target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = buf[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named stack.ErrorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch name := errorTarget.Name.String(); name {
+ case stack.ErrorTargetName:
+ return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ default:
+ // User defined chain.
+ return &stack.UserChainTarget{
+ Name: name,
+ NetworkProtocol: filter.NetworkProtocol(),
+ }, nil
+ }
+}
+
+type redirectTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *redirectTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*redirectTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
// This is a redirect target named redirect
- target := linux.XTRedirectTarget{
+ xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(target.Target.Name[:], redirectTargetName)
+ copy(xt.Target.Name[:], stack.RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
- target.NfRange.RangeSize = 1
- if rt.RangeProtoSpecified {
- target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeSize = 1
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
+}
+
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) < linux.SizeOfXTRedirectTarget {
+ nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("redirectTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = buf[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+
+ // RangeSize should be 1.
+ nfRange := redirectTarget.NfRange
+ if nfRange.RangeSize != 1 {
+ nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
}
- // Convert port from little endian to big endian.
- port := make([]byte, 2)
- binary.LittleEndian.PutUint16(port, rt.MinPort)
- target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
- binary.LittleEndian.PutUint16(port, rt.MaxPort)
- target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
}
-func marshalJumpTarget(jt JumpTarget) []byte {
- nflog("convert to binary: marshalling jump target")
+type nfNATTarget struct {
+ Target linux.XTEntryTarget
+ Range linux.NFNATRange
+}
- // The target's name will be the empty string.
- target := linux.XTStandardTarget{
+const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+
+type nfNATTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *nfNATTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
+ nt := nfNATTarget{
Target: linux.XTEntryTarget{
- TargetSize: linux.SizeOfXTStandardTarget,
+ TargetSize: nfNATMarhsalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
- // Verdict is overloaded by the ABI. When positive, it holds
- // the jump offset from the start of the table.
- Verdict: int32(jt.Offset),
}
+ copy(nt.Target.Name[:], stack.RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.Addr)
+ copy(nt.Range.MaxAddr[:], rt.Addr)
- ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ nt.Range.MinProto = htons(rt.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarhsalledSize)
+ return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-// translateFromStandardVerdict translates verdicts the same way as the iptables
-// tool.
-func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
- switch verdict {
- case stack.RuleAccept:
- return -linux.NF_ACCEPT - 1
- case stack.RuleDrop:
- return -linux.NF_DROP - 1
- case stack.RuleReturn:
- return linux.NF_RETURN
- default:
- // TODO(gvisor.dev/issue/170): Support Jump.
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if size := nfNATMarhsalledSize; len(buf) < size {
+ nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
}
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("nfNATTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
+ binary.Unmarshal(buf, usermem.ByteOrder, &natRange)
+
+ // We don't support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("nfNATTargetMaker: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("nfNATTargetMaker: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/3549): Check for other flags.
+ // For now, redirect target only supports destination change.
+ if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ }
+
+ return &target, nil
}
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
-func translateToStandardTarget(val int32) (stack.Target, error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return stack.AcceptTarget{}, nil
+ return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
case -linux.NF_DROP - 1:
- return stack.DropTarget{}, nil
+ return &stack.DropTarget{NetworkProtocol: netProto}, nil
case -linux.NF_QUEUE - 1:
- return nil, errors.New("unsupported iptables verdict QUEUE")
+ nflog("unsupported iptables verdict QUEUE")
+ return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return stack.ReturnTarget{}, nil
+ return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
default:
- return nil, fmt.Errorf("unknown iptables verdict %d", val)
+ nflog("unknown iptables verdict %d", val)
+ return nil, syserr.ErrInvalidArgument
}
}
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.Target, *syserr.Error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
- return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
+ nflog("optVal has insufficient size for entry target %d", len(optVal))
+ return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
buf := optVal[:linux.SizeOfXTEntryTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &target)
- switch target.Name.String() {
- case "":
- // Standard target.
- if len(optVal) != linux.SizeOfXTStandardTarget {
- return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
- }
- var standardTarget linux.XTStandardTarget
- buf = optVal[:linux.SizeOfXTStandardTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
-
- if standardTarget.Verdict < 0 {
- // A Verdict < 0 indicates a non-jump verdict.
- return translateToStandardTarget(standardTarget.Verdict)
- }
- // A verdict >= 0 indicates a jump.
- return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
-
- case errorTargetName:
- // Error target.
- if len(optVal) != linux.SizeOfXTErrorTarget {
- return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
- }
- var errorTarget linux.XTErrorTarget
- buf = optVal[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
-
- // Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named errorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
- // * To mark the start of a user defined chain. These
- // rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case errorTargetName:
- nflog("set entries: error target")
- return stack.ErrorTarget{}, nil
- default:
- // User defined chain.
- nflog("set entries: user-defined target %q", name)
- return stack.UserChainTarget{Name: name}, nil
- }
-
- case redirectTargetName:
- // Redirect target.
- if len(optVal) < linux.SizeOfXTRedirectTarget {
- return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
- }
-
- if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
- return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p)
- }
-
- var redirectTarget linux.XTRedirectTarget
- buf = optVal[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
-
- // Copy linux.XTRedirectTarget to stack.RedirectTarget.
- var target stack.RedirectTarget
- nfRange := redirectTarget.NfRange
-
- // RangeSize should be 1.
- if nfRange.RangeSize != 1 {
- return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize)
- }
-
- // TODO(gvisor.dev/issue/170): Check if the flags are valid.
- // Also check if we need to map ports or IP.
- // For now, redirect target only supports destination port change.
- // Port range and IP range are not supported yet.
- if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags)
- }
- target.RangeProtoSpecified = true
-
- target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
- target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
-
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
- if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
- return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
- }
-
- // Convert port from big endian to little endian.
- port := make([]byte, 2)
- binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
- target.MinPort = binary.LittleEndian.Uint16(port)
-
- binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
- target.MaxPort = binary.LittleEndian.Uint16(port)
- return target, nil
- }
- // Unknown target.
- return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String())
+ return unmarshalTarget(target, filter, optVal)
}
// JumpTarget implements stack.Target.
@@ -274,9 +376,31 @@ type JumpTarget struct {
// RuleNum is the rule to jump to.
RuleNum int
+
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (jt *JumpTarget) ID() stack.TargetID {
+ return stack.TargetID{
+ NetworkProtocol: jt.NetworkProtocol,
+ }
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
+
+func ntohs(port uint16) uint16 {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, port)
+ return usermem.ByteOrder.Uint16(buf)
+}
+
+func htons(port uint16) uint16 {
+ buf := make([]byte, 2)
+ usermem.ByteOrder.PutUint16(buf, port)
+ return binary.BigEndian.Uint16(buf)
+}
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index a38d25da9..c83b23242 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
return fd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 6fede181a..87e30d742 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -198,7 +198,6 @@ var Metrics = tcpip.Stats{
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
- InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
@@ -1513,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &vP, nil
case linux.IP6T_ORIGINAL_DST:
- // TODO(gvisor.dev/issue/170): ip6tables.
- return nil, syserr.ErrInvalidArgument
+ if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
if outLen < linux.SizeOfIPTGetinfo {
@@ -1556,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
return &entries, nil
+ case linux.IP6T_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1719,6 +1747,26 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
}
return &entries, nil
+ case linux.IPT_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv4 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIP(t, name)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index c0212ad76..4c6791fff 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index a89583dad..cc7408698 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -7,10 +7,21 @@ go_template_instance(
name = "socket_refs",
out = "socket_refs.go",
package = "unix",
- prefix = "socketOpsCommon",
+ prefix = "socketOperations",
template = "//pkg/refs_vfs2:refs_template",
types = {
- "T": "socketOpsCommon",
+ "T": "SocketOperations",
+ },
+)
+
+go_template_instance(
+ name = "socket_vfs2_refs",
+ out = "socket_vfs2_refs.go",
+ package = "unix",
+ prefix = "socketVFS2",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "SocketVFS2",
},
)
@@ -20,6 +31,7 @@ go_library(
"device.go",
"io.go",
"socket_refs.go",
+ "socket_vfs2_refs.go",
"unix.go",
"unix_vfs2.go",
],
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 917055cea..f80011ce4 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -55,6 +55,7 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socketOperationsRefs
socketOpsCommon
}
@@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
return fs.NewFile(ctx, d, flags, &s)
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketOperations) DecRef(ctx context.Context) {
+ s.socketOperationsRefs.DecRef(func() {
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *SocketOperations) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
//
// +stateify savable
type socketOpsCommon struct {
- socketOpsCommonRefs
socket.SendReceiveTimeout
ep transport.Endpoint
@@ -101,23 +118,6 @@ type socketOpsCommon struct {
abstractNamespace *kernel.AbstractSocketNamespace
}
-// DecRef implements RefCounter.DecRef.
-func (s *socketOpsCommon) DecRef(ctx context.Context) {
- s.socketOpsCommonRefs.DecRef(func() {
- s.ep.Close(ctx)
- if s.abstractNamespace != nil {
- s.abstractNamespace.Remove(s.abstractName, s)
- }
- })
-}
-
-// Release implemements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(ctx context.Context) {
- // Release only decrements a reference on s because s may be referenced in
- // the abstract socket namespace.
- s.DecRef(ctx)
-}
-
func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 8b1abd922..3345124cc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -45,6 +45,7 @@ type SocketVFS2 struct {
vfs.DentryMetadataFileDescriptionImpl
vfs.LockFD
+ socketVFS2Refs
socketOpsCommon
}
@@ -91,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
return vfsfd, nil
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketVFS2) DecRef(ctx context.Context) {
+ s.socketVFS2Refs.DecRef(func() {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 52281ccc2..396744597 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -17,7 +17,6 @@
package strace
import (
- "encoding/binary"
"fmt"
"strconv"
"strings"
@@ -294,7 +293,7 @@ func itimerval(t *kernel.Task, addr usermem.Addr) string {
}
interval := timeval(t, addr)
- value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{})))
+ value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
@@ -304,7 +303,7 @@ func itimerspec(t *kernel.Task, addr usermem.Addr) string {
}
interval := timespec(t, addr)
- value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{})))
+ value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 75752b2e6..a2e441448 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -21,6 +21,7 @@ go_library(
"sys_identity.go",
"sys_inotify.go",
"sys_lseek.go",
+ "sys_membarrier.go",
"sys_mempolicy.go",
"sys_mmap.go",
"sys_mount.go",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 5f26697d2..9c9def7cd 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -376,7 +376,7 @@ var AMD64 = &kernel.SyscallTable{
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
322: syscalls.Supported("execveat", Execveat),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls implemented after 325 are "backports" from versions
@@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{
96: syscalls.Supported("set_tid_address", SetTidAddress),
97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 99: syscalls.Supported("set_robust_list", SetRobustList),
+ 100: syscalls.Supported("get_robust_list", GetRobustList),
101: syscalls.Supported("nanosleep", Nanosleep),
102: syscalls.Supported("getitimer", Getitimer),
103: syscalls.Supported("setitimer", Setitimer),
@@ -695,7 +695,7 @@ var ARM64 = &kernel.SyscallTable{
280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
281: syscalls.Supported("execveat", Execveat),
282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls after 284 are "backports" from versions of Linux after 4.4.
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 98331eb3c..519066a47 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -84,6 +84,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
}
rel = f.Dirent
if !fs.IsDir(rel.Inode.StableAttr) {
+ f.DecRef(t)
return syserror.ENOTDIR
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go
new file mode 100644
index 000000000..63ee5d435
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_membarrier.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Membarrier implements syscall membarrier(2).
+func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ cmd := args[0].Int()
+ flags := args[1].Uint()
+
+ switch cmd {
+ case linux.MEMBARRIER_CMD_QUERY:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ var supportedCommands uintptr
+ if t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ supportedCommands |= linux.MEMBARRIER_CMD_GLOBAL |
+ linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED
+ }
+ if t.RSeqAvailable() {
+ supportedCommands |= linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ
+ }
+ return supportedCommands, nil, nil
+ case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier()
+ case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ // no-op
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierPrivate()
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ:
+ if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.MemoryManager().IsMembarrierRSeqEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ // MEMBARRIER_CMD_FLAG_CPU and cpu_id are ignored since we don't have
+ // the ability to preempt specific CPUs.
+ return 0, nil, t.Kernel().Platform.PreemptAllCPUs()
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierRSeq()
+ return 0, nil, nil
+ default:
+ // Probably a command we don't implement.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 674d341b6..6320593f0 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -26,8 +26,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
addr := args[0].Pointer()
mf := t.Kernel().MemoryFile()
- mf.UpdateUsage()
- _, totalUsage := usage.MemoryAccounting.Copy()
+ mfUsage, err := mf.TotalUsage()
+ if err != nil {
+ return 0, nil, err
+ }
+ memStats, _ := usage.MemoryAccounting.Copy()
+ totalUsage := mfUsage + memStats.Mapped
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
memFree := totalSize - totalUsage
if memFree > totalSize {
@@ -37,12 +41,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Only a subset of the fields in sysinfo_t make sense to return.
si := linux.Sysinfo{
- Procs: uint16(len(t.PIDNamespace().Tasks())),
+ Procs: uint16(t.Kernel().TaskSet().Root.NumTasks()),
Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
TotalRAM: totalSize,
FreeRAM: memFree,
Unit: 1,
}
- _, err := si.CopyOut(t, addr)
+ _, err = si.CopyOut(t, addr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index 066ee0863..c8ce2aabc 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -110,8 +110,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user
}
// Load the new TaskContext.
- mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
- defer mntns.DecRef(t)
+ mntns := t.MountNamespaceVFS2()
wd := t.FSContext().WorkingDirectoryVFS2()
defer wd.DecRef(t)
remainingTraversals := uint(linux.MaxSymlinkTraversals)
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 0df3bd449..c50fd97eb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -163,6 +163,7 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
s.Table[5] = syscalls.Supported("setxattr", SetXattr)
s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
@@ -200,6 +201,7 @@ func Override() {
s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[45] = syscalls.Supported("truncate", Truncate)
s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
s.Table[48] = syscalls.Supported("faccessat", Faccessat)
s.Table[49] = syscalls.Supported("chdir", Chdir)
s.Table[50] = syscalls.Supported("fchdir", Fchdir)
@@ -221,12 +223,14 @@ func Override() {
s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
s.Table[69] = syscalls.Supported("preadv", Preadv)
s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[71] = syscalls.Supported("sendfile", Sendfile)
s.Table[72] = syscalls.Supported("pselect", Pselect)
s.Table[73] = syscalls.Supported("ppoll", Ppoll)
s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
s.Table[76] = syscalls.Supported("splice", Splice)
s.Table[77] = syscalls.Supported("tee", Tee)
s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[79] = syscalls.Supported("newfstatat", Newfstatat)
s.Table[80] = syscalls.Supported("fstat", Fstat)
s.Table[81] = syscalls.Supported("sync", Sync)
s.Table[82] = syscalls.Supported("fsync", Fsync)
@@ -251,8 +255,10 @@ func Override() {
s.Table[210] = syscalls.Supported("shutdown", Shutdown)
s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[213] = syscalls.Supported("readahead", Readahead)
s.Table[221] = syscalls.Supported("execve", Execve)
s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil)
s.Table[242] = syscalls.Supported("accept4", Accept4)
s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
s.Table[267] = syscalls.Supported("syncfs", Syncfs)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 8093ca55c..c855608db 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -92,7 +92,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index bdfd3ca8f..7ad0eaf86 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -61,11 +61,14 @@ func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *au
panic("cannot instaniate an anon filesystem")
}
-// Name implemenents FilesystemType.Name.
+// Name implements FilesystemType.Name.
func (anonFilesystemType) Name() string {
return "none"
}
+// Release implemenents FilesystemType.Release.
+func (anonFilesystemType) Release(ctx context.Context) {}
+
// anonFilesystem is the implementation of FilesystemImpl that backs
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
//
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 1eba0270f..183957ad8 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -827,7 +827,7 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
}
// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/etc.
+// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
type FileReadWriteSeeker struct {
FD *FileDescription
Ctx context.Context
@@ -835,11 +835,18 @@ type FileReadWriteSeeker struct {
WOpts WriteOptions
}
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
+ return int(n), err
+}
+
// Read implements io.ReadWriteSeeker.Read.
func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
dst := usermem.BytesIOSequence(p)
- ret, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(ret), err
+ n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
+ return int(n), err
}
// Seek implements io.ReadWriteSeeker.Seek.
@@ -847,9 +854,16 @@ func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
return f.FD.Seek(f.Ctx, offset, int32(whence))
}
+// WriteAt implements io.WriterAt.WriteAt.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
+ return int(n), err
+}
+
// Write implements io.ReadWriteSeeker.Write.
func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
buf := usermem.BytesIOSequence(p)
- ret, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(ret), err
+ n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
+ return int(n), err
}
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index bc19db1d5..9d54cc4ed 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -33,6 +33,9 @@ type FilesystemType interface {
// Name returns the name of this FilesystemType.
Name() string
+
+ // Release releases all resources held by this FilesystemType.
+ Release(ctx context.Context)
}
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index dfc3ae6c0..78f115bfa 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -46,8 +46,9 @@ import (
// +stateify savable
type Mount struct {
// vfs, fs, root are immutable. References are held on fs and root.
+ // Note that for a disconnected mount, root may be nil.
//
- // Invariant: root belongs to fs.
+ // Invariant: if not nil, root belongs to fs.
vfs *VirtualFilesystem
fs *Filesystem
root *Dentry
@@ -498,7 +499,9 @@ func (mnt *Mount) DecRef(ctx context.Context) {
mnt.vfs.mounts.seq.EndWrite()
mnt.vfs.mountMu.Unlock()
}
- mnt.root.DecRef(ctx)
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
+ }
mnt.fs.DecRef(ctx)
if vd.Ok() {
vd.DecRef(ctx)
@@ -724,14 +727,12 @@ func (mnt *Mount) Root() *Dentry {
return mnt.root
}
-// Root returns mntns' root. A reference is taken on the returned
-// VirtualDentry.
+// Root returns mntns' root. It does not take a reference on the returned Dentry.
func (mntns *MountNamespace) Root() VirtualDentry {
vd := VirtualDentry{
mount: mntns.root,
dentry: mntns.root.root,
}
- vd.IncRef()
return vd
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 5bd756ea5..38d2701d2 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -158,6 +158,16 @@ func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
return nil
}
+// Release drops references on filesystem objects held by vfs.
+//
+// Precondition: This must be called after VFS.Init() has succeeded.
+func (vfs *VirtualFilesystem) Release(ctx context.Context) {
+ vfs.anonMount.DecRef(ctx)
+ for _, fst := range vfs.fsTypes {
+ fst.fsType.Release(ctx)
+ }
+}
+
// PathOperation specifies the path operated on by a VFS method.
//
// PathOperation is passed to VFS methods by pointer to reduce memory copying:
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index ea0c5413d..8db70a700 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -84,8 +84,8 @@ type VectorisedView struct {
size int
}
-// NewVectorisedView creates a new vectorised view from an already-allocated slice
-// of View and sets its size.
+// NewVectorisedView creates a new vectorised view from an already-allocated
+// slice of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
@@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) {
}
// Clone returns a clone of this VectorisedView.
-// If the buffer argument is large enough to contain all the Views of this VectorisedView,
-// the method will avoid allocations and use the buffer to store the Views of the clone.
+// If the buffer argument is large enough to contain all the Views of this
+// VectorisedView, the method will avoid allocations and use the buffer to
+// store the Views of the clone.
func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
@@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) {
return newFirst, true
}
-// Size returns the size in bytes of the entire content stored in the vectorised view.
+// Size returns the size in bytes of the entire content stored in the
+// vectorised view.
func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
+ return vv.ToOwnedView()
+}
+
+// ToOwnedView returns a single view containing the content of the vectorised
+// view that vv does not own.
+func (vv *VectorisedView) ToOwnedView() View {
u := make([]byte, 0, vv.size)
for _, v := range vv.views {
u = append(u, v...)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 19627fa9b..d4d785cca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
+ }
+ }
+}
+
+// IPv4Options returns a checker that checks the options in an IPv4 packet.
+func IPv4Options(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ options := ip.Options()
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(options) == 0 {
+ return
+ }
+ if diff := cmp.Diff(want, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
}
}
}
@@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker {
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
@@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker {
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
@@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
+// not contain any SACK blocks in the TCP options.
func TCPNoSACKBlockChecker() TransportChecker {
return TCPSACKBlockChecker(nil)
}
@@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker {
}
}
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
+// and potentially additional ICMPv4 header fields.
func ICMPv4(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
@@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
index 1193f1d7d..f7a4fbde1 100644
--- a/pkg/tcpip/faketime/faketime.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -24,6 +24,26 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
+// NullClock implements a clock that never advances.
+type NullClock struct{}
+
+var _ tcpip.Clock = (*NullClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (*NullClock) NowNanoseconds() int64 {
+ return 0
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (*NullClock) NowMonotonic() int64 {
+ return 0
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer {
+ return nil
+}
+
// ManualClock implements tcpip.Clock and only advances manually with Advance
// method.
type ManualClock struct {
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index eaface8cb..95ade0e5c 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
-// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// IsMulticastEthernetAddress returns true if the address is a multicast
+// ethernet address.
+func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
+}
+
+// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
- // Must be of the right length.
if len(addr) != EthernetAddressSize {
return false
}
- // Must not be unspecified.
if addr == unspecifiedEthernetAddress {
return false
}
- // Must not be a multicast.
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
return false
}
- // addr is a valid unicast ethernet address.
return true
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 14413f2ce..3bc8b2b21 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
}
}
+func TestIsMulticastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ true,
+ },
+ {
+ "Unicast",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsMulticastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
+
func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c00bcadfb..504408878 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
-// SetPointer sets the pointer field in a Parameter error packet.
-// This is the first byte of the type specific data field.
-func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
-
-// SetTypeSpecific sets the full 32 bit type specific data field.
-func (b ICMPv4) SetTypeSpecific(val uint32) {
- binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val)
-}
-
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4eb5abd79..4303fc5d5 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -49,11 +49,6 @@ const (
// neighbor advertisement packet.
ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
- // ICMPv6NeighborAdvertSize is size of a neighbor advertisement
- // including the NDP Target Link Layer option for an Ethernet
- // address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
-
// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
@@ -156,9 +151,14 @@ const (
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
+ // ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
ICMPv6ErroneousHeader ICMPv6Code = 0
- ICMPv6UnknownHeader ICMPv6Code = 1
- ICMPv6UnknownOption ICMPv6Code = 2
+
+ // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
+ ICMPv6UnknownHeader ICMPv6Code = 1
+
+ // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
+ ICMPv6UnknownOption ICMPv6Code = 2
)
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
@@ -177,7 +177,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
-// SetTypeSpecific sets the full 32 bit type specific data field.
+// TypeSpecific returns the type specific data field.
+func (b ICMPv6) TypeSpecific() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
+}
+
+// SetTypeSpecific sets the type specific data field.
func (b ICMPv6) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index b07d9991d..4c6e4be64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,10 +16,29 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// RFC 971 defines the fields of the IPv4 header on page 11 using the following
+// diagram: ("Figure 4")
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Identification |Flags| Fragment Offset |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Time to Live | Protocol | Header Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Source Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Destination Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options | Padding |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
const (
versIHL = 0
tos = 1
@@ -33,6 +52,7 @@ const (
checksum = 10
srcAddr = 12
dstAddr = 16
+ options = 20
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -76,7 +96,8 @@ type IPv4Fields struct {
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
-// Always call IsValid() to validate an instance of IPv4 before using other methods.
+// Always call IsValid() to validate an instance of IPv4 before using other
+// methods.
type IPv4 []byte
const (
@@ -151,13 +172,44 @@ func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
- return int(b[versIHL] >> 4)
+ return int(b[versIHL] >> ipVersionShift)
}
+// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits
+// of the first byte, and is counted in multiples of 4 bytes.
+//
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// (...)
+// Version: 4 bits
+// The Version field indicates the format of the internet header. This
+// document describes version 4.
+//
+// IHL: 4 bits
+// Internet Header Length is the length of the internet header in 32
+// bit words, and thus points to the beginning of the data. Note that
+// the minimum value for a correct header is 5.
+//
+const (
+ ipVersionShift = 4
+ ipIHLMask = 0x0f
+ IPv4IHLStride = 4
+)
+
// HeaderLength returns the value of the "header length" field of the ipv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
- return (b[versIHL] & 0xf) * 4
+ return (b[versIHL] & ipIHLMask) * IPv4IHLStride
+}
+
+// SetHeaderLength sets the value of the "Internet Header Length" field.
+func (b IPv4) SetHeaderLength(hdrLen uint8) {
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
// ID returns the value of the identifier field of the ipv4 header.
@@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// Options returns a a buffer holding the options.
+func (b IPv4) Options() []byte {
+ hdrLen := b.HeaderLength()
+ return b[options:hdrLen:hdrLen]
+}
+
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.Protocol())
@@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
+// SetTTL sets the "Time to Live" field of the IPv4 header.
+func (b IPv4) SetTTL(v byte) {
+ b[ttl] = v
+}
+
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
@@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 {
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
- b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 0761a1807..c5d8a3456 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -34,6 +34,9 @@ const (
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
+
+ // IPv6FixedHeaderSize is the size of the fixed header.
+ IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -69,7 +72,7 @@ type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
- IPv6MinimumSize = 40
+ IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
@@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool {
return addr[0] != 0xff
}
+const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
- const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// IsSolicitedNodeAddr determines whether the address is a solicited-node
+// multicast address.
+func IsSolicitedNodeAddr(addr tcpip.Address) bool {
+ return solicitedNodeMulticastPrefix == addr[:len(addr)-3]
+}
+
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
//
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 3499d8399..583c2c5d3 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader bytes.Reader
+
+ // optionOffset is the number of bytes from the first byte of the
+ // options field to the beginning of the current option.
+ optionOffset uint32
+
+ // nextOptionOffset is the offset of the next option.
+ nextOptionOffset uint32
+}
+
+// OptionOffset returns the number of bytes parsed while processing the
+// option field of the current Extension Header.
+func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
+ return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
@@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// the options data, or an error occured.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
+ i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
@@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
+ i.nextOptionOffset = i.optionOffset + 1
continue
}
@@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
- // Special-case the variable length padding option to avoid a copy.
- if id == ipv6PadNExtHdrOptionIdentifier {
- // Do we have enough bytes in the reader for the PadN option?
- if n := i.reader.Len(); n < int(length) {
- // Reset the reader to effectively consume the remaining buffer.
- i.reader.Reset(nil)
-
- // We return the same error as if we failed to read a non-padding option
- // so consumers of this iterator don't need to differentiate between
- // padding and non-padding options.
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
- }
+ // Do we have enough bytes in the reader for the next option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
+ switch id {
+ case ipv6PadNExtHdrOptionIdentifier:
+ // Special-case the variable length padding option to avoid a copy.
if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
-
- // End processing of the PadN option and continue processing the buffer as
- // a new option.
continue
- }
-
- bytes := make([]byte, length)
- if n, err := io.ReadFull(&i.reader, bytes); err != nil {
- // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
- // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
- // Length field found in the option.
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ default:
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
-
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
-
- return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
@@ -382,6 +396,29 @@ type IPv6PayloadIterator struct {
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
+
+ // headerOffset is the offset of the beginning of the current extension
+ // header starting from the beginning of the fixed header.
+ headerOffset uint32
+
+ // parseOffset is the byte offset into the current extension header of the
+ // field we are currently examining. It can be added to the header offset
+ // if the absolute offset within the packet is required.
+ parseOffset uint32
+
+ // nextOffset is the offset of the next header.
+ nextOffset uint32
+}
+
+// HeaderOffset returns the offset to the start of the extension
+// header most recently processed.
+func (i IPv6PayloadIterator) HeaderOffset() uint32 {
+ return i.headerOffset
+}
+
+// ParseOffset returns the number of bytes successfully parsed.
+func (i IPv6PayloadIterator) ParseOffset() uint32 {
+ return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
@@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
nextHdrIdentifier: nextHdrIdentifier,
payload: payload.Clone(nil),
// We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
- reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ nextOffset: IPv6FixedHeaderSize,
}
}
@@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occured.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ i.headerOffset = i.nextOffset
+ i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
@@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
return IPv6RoutingExtHdr(bytes), false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
- // We ignore the returned bytes becauase we know the fragment extension
+ // We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
@@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
+ i.parseOffset++
var length uint8
length, err = i.reader.ReadByte()
i.payload.TrimFront(1)
+
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
@@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length = 0
}
+ // Make parseOffset point to the first byte of the Extension Header
+ // specific data.
+ i.parseOffset++
+
+ // length is in 8 byte chunks but doesn't include the first one.
+ // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
+ // in section 4.8 for new extension headers at the top of page 24.
+ // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
+ // units, not including the first 8 octets.
+ i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
+
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if bytes == nil {
bytes = make([]byte, bytesLen)
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
index b5540bf66..17a49d4fa 100644
--- a/pkg/tcpip/header/ipversion_test.go
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -22,7 +22,7 @@ import (
func TestIPv4(t *testing.T) {
b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
+ b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize})
const want = header.IPv4Version
if v := header.IPVersion(b); v != want {
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 522135557..5ca75c834 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -139,6 +139,7 @@ traverseExtensions:
// Returns true if the header was successfully parsed.
func UDP(pkt *stack.PacketBuffer) bool {
_, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
return ok
}
@@ -162,5 +163,6 @@ func TCP(pkt *stack.PacketBuffer) bool {
}
_, ok = pkt.TransportHeader().Consume(hdrLen)
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
return ok
}
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
new file mode 100644
index 000000000..9f31c1ffc
--- /dev/null
+++ b/pkg/tcpip/link/pipe/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = ["pipe.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..76f563811
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ capabilities: capabilities,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ linked: ep1,
+ capabilities: capabilities,
+ }
+ ep1.linked = ep2
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ capabilities stack.LinkEndpointCapabilities
+ linkAddr tcpip.LinkAddress
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ onWritePacket func(*stack.PacketBuffer)
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // The pipe endpoint will accept all multicast/broadcast link traffic and only
+ // unicast traffic destined to itself.
+ if len(e.linked.linkAddr) != 0 &&
+ r.RemoteLinkAddress != e.linked.linkAddr &&
+ r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
+ !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
+ return nil
+ }
+
+ e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index b6ddbe81e..f94491026 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) {
}
}
+// NICID returns the NIC ID of the device.
+//
+// Must only be called after the device has been attached to an endpoint.
+func (d *Device) NICID() tcpip.NICID {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+
+ if d.endpoint == nil {
+ panic("called NICID on a device that has not been attached")
+ }
+
+ return d.endpoint.nicID
+}
+
// SetIff services TUNSETIFF ioctl(2) request.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+//
+// Returns true if a new NIC was created; false if an existing one was attached.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
// Input validations.
@@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
prefix := "tun"
@@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return nil
+ return created, nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
- if nic, found := s.GetNICByName(name); found {
- endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
+ endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, false, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, nil
+ return endpoint, false, nil
}
}
@@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, nil
+ return endpoint, true, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, false, syserror.EINVAL
}
}
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 46083925c..59710352b 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -9,6 +9,7 @@ go_test(
"ip_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -17,6 +18,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
],
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b025bb087..7df77c66e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -18,6 +18,8 @@
package arp
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,38 +35,73 @@ const (
ProtocolAddress = tcpip.Address("arp")
)
-// endpoint implements stack.NetworkEndpoint.
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- protocol *protocol
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ stack.AddressableEndpointState
+
+ protocol *protocol
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
+func (e *endpoint) Enable() *tcpip.Error {
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ e.setEnabled(true)
+ return nil
+}
+
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+func (e *endpoint) setEnabled(v bool) {
+ if v {
+ atomic.StoreUint32(&e.enabled, 1)
+ } else {
+ atomic.StoreUint32(&e.enabled, 0)
+ }
+}
+
+func (e *endpoint) Disable() {
+ e.setEnabled(false)
+}
+
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
func (e *endpoint) DefaultTTL() uint8 {
return 0
}
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
+ lmtu := e.nic.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
@@ -85,6 +122,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
@@ -95,15 +136,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -112,24 +153,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ origSender := h.HardwareAddressSender()
+ r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
- packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.HardwareAddressTarget(), origSender)
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -161,14 +210,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
protocol: p,
- nicID: nicID,
- linkEP: sender,
+ nic: nic,
linkAddrCache: linkAddrCache,
nud: nud,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
@@ -179,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
+ NetProto: ProtocolNumber,
RemoteLinkAddress: remoteLinkAddr,
}
if len(r.RemoteLinkAddress) == 0 {
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 96c5f42f8..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -43,5 +45,8 @@ go_test(
library = ":fragmentation",
deps = [
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6a4843f92..ed502a473 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,7 +13,7 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
@@ -25,12 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
- // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
- DefaultReassembleTimeout = 30 * time.Second
-
// HighFragThreshold is the threshold at which we start trimming old
// fragmented packets. Linux uses a default value of 4 MB. See
// net.ipv4.ipfrag_high_thresh for more information.
@@ -81,6 +79,8 @@ type Fragmentation struct {
size int
timeout time.Duration
blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +97,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -110,13 +110,17 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
blockSize = minBlockSize
}
- return &Fragmentation{
+ f := &Fragmentation{
reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
blockSize: blockSize,
+ clock: clock,
}
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
}
// Process processes an incoming fragment belonging to an ID and returns a
@@ -155,15 +159,17 @@ func (f *Fragmentation) Process(
f.mu.Lock()
r, ok := f.reassemblers[id]
- if ok && r.tooOld(f.timeout) {
- // This is very likely to be an id-collision or someone performing a slow-rate attack.
- f.release(r)
- ok = false
- }
if !ok {
- r = newReassembler(id)
+ r = newReassembler(id, f.clock)
f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
}
f.mu.Unlock()
@@ -211,3 +217,102 @@ func (f *Fragmentation) release(r *reassembler) {
f.size = 0
}
}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r)
+ }
+}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ innerMTU int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ innerMTU: innerMTU,
+ fragmentCount: fragmentCount,
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 416604659..d3c7d7f92 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -20,9 +20,16 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
+// reassembleTimeout is dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
@@ -95,7 +102,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
@@ -131,25 +138,126 @@ func TestFragmentationProcess(t *testing.T) {
}
func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(minBlockSize, 1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err)
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
}
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // sizeAfterEvent is the expected size of the fragmentation instance after
+ // the event.
+ sizeAfterEvent int
+ }
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.size, event.sizeAfterEvent; got != want {
+ t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
}
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
@@ -173,7 +281,7 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
@@ -268,7 +376,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
_, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
@@ -279,3 +387,113 @@ func TestErrors(t *testing.T) {
})
}
}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ innerMTU int
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ innerMTU: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ innerMTU: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ innerMTU: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ innerMTU: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := fragPkt.Size(); got > test.innerMTU {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index f044867dc..9bb051a30 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -40,15 +40,15 @@ type reassembler struct {
deleted int
heap fragHeap
done bool
- creationTime time.Time
+ creationTime int64
}
-func newReassembler(id FragmentID) *reassembler {
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
heap: make(fragHeap, 0, 8),
- creationTime: time.Now(),
+ creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
@@ -116,10 +116,6 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf
return res, r.proto, true, consumed, nil
}
-func (r *reassembler) tooOld(timeout time.Duration) bool {
- return time.Now().Sub(r.creationTime) > timeout
-}
-
func (r *reassembler) checkDoneOrMark() bool {
r.mu.Lock()
prev := r.done
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index dff7c9dcb..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -18,6 +18,8 @@ import (
"math"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
type updateHolesInput struct {
@@ -94,7 +96,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(FragmentID{})
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4640ca95c..d436873b6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -17,6 +17,7 @@ package ip_test
import (
"testing"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -25,26 +26,35 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- nicID = 1
+ localIPv4Addr = "\x0a\x00\x00\x01"
+ remoteIPv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
+var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv4Addr,
+ PrefixLen: 24,
+}
+
+var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv6Addr,
+ PrefixLen: 120,
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -225,7 +235,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack(t *testing.T) *stack.Stack {
+func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
t.Helper()
s := stack.New(stack.Options{
@@ -237,22 +247,278 @@ func buildDummyStack(t *testing.T) *stack.Stack {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err)
+ v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
}
+ return s, e
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s, _ := buildDummyStackWithLinkEndpoint(t)
return s
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ testObject
+
+ mu struct {
+ sync.RWMutex
+ disabled bool
+ }
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (t *testInterface) Enabled() bool {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return !t.mu.disabled
+}
+
+func (t *testInterface) setEnabled(v bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mu.disabled = !v
+}
+
+func TestSourceAddressValidation(t *testing.T) {
+ rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ rxICMP func(*channel.Endpoint, tcpip.Address)
+ valid bool
+ }{
+ {
+ name: "IPv4 valid",
+ srcAddress: "\x01\x02\x03\x04",
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 valid",
+ srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddress: "\xe0\x00\x00\x01",
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ rxICMP: rxIPv6ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 broadcast",
+ srcAddress: header.IPv4Broadcast,
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 subnet broadcast",
+ srcAddress: func() tcpip.Address {
+ subnet := localIPv4AddrWithPrefix.Subnet()
+ return subnet.Broadcast()
+ }(),
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t)
+ test.rxICMP(e, test.srcAddress)
+
+ var wantValid uint64
+ if test.valid {
+ wantValid = 1
+ }
+
+ if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
+ t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
+ t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
+ }
+ })
+ }
+}
+
+func TestEnableWhenNICDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protocolFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protocolFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var nic testInterface
+ nic.setEnabled(false)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
+ })
+ p := s.NetworkProtocolInstance(test.protoNum)
+
+ // We pass nil for all parameters except the NetworkInterface and Stack
+ // since Enable only depends on these.
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
+
+ // The endpoint should initially be disabled, regardless the NIC's enabled
+ // status.
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Attempting to enable the endpoint while the NIC is disabled should
+ // fail.
+ nic.setEnabled(false)
+ if err := ep.Enable(); err != tcpip.ErrNotPermitted {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ }
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status so we "enable" the NIC to read just the endpoint's
+ // enabled status.
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Enabling the interface after the NIC has been enabled should succeed.
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ if !ep.Enabled() {
+ t.Fatal("got ep.Enabled() = false, want = true")
+ }
+
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status.
+ nic.setEnabled(false)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Disabling the endpoint when the NIC is enabled should make the endpoint
+ // disabled.
+ nic.setEnabled(true)
+ ep.Disable()
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ })
+ }
+}
+
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -268,12 +534,12 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -287,12 +553,21 @@ func TestIPv4Send(t *testing.T) {
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv4(view)
@@ -301,8 +576,8 @@ func TestIPv4Receive(t *testing.T) {
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
@@ -311,12 +586,12 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -327,8 +602,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -352,18 +627,26 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
t.Fatal(err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
@@ -375,7 +658,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
+ DstAddr: localIPv4Addr,
})
// Create the ICMP header.
@@ -393,8 +676,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
+ SrcAddr: localIPv4Addr,
+ DstAddr: remoteIPv4Addr,
})
// Make payload be non-zero.
@@ -404,28 +687,37 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 24
frag1 := buffer.NewView(totalLen)
@@ -437,8 +729,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -453,8 +745,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -462,12 +754,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -480,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
@@ -492,18 +784,26 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
for i := 0; i < len(payload); i++ {
@@ -517,12 +817,12 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -536,12 +836,20 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv6(view)
@@ -549,8 +857,8 @@ func TestIPv6Receive(t *testing.T) {
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -559,12 +867,12 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -576,8 +884,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -608,7 +916,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r, err := buildIPv6Route(
- localIpv6Addr,
+ localIPv6Addr,
"\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
)
if err != nil {
@@ -616,12 +924,20 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
@@ -635,7 +951,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -651,8 +967,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
})
// Build the fragmentation header if needed.
@@ -674,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index f9c2aa980..7fc12e229 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -27,12 +28,15 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3e5cf2ad9..3407755ed 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,8 @@
package ipv4
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -40,7 +42,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
// Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
+ headerChecksum := h.Checksum()
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
+ calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(headerChecksum)
+ if calculatedChecksum != headerChecksum {
+ // It's possible that a raw socket still expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
- // Make a copy of data before pkt gets sent to raw socket.
- // DeliverTransportPacket will take ownership of pkt.
- replyData := pkt.Data.Clone(nil)
- replyData.TrimFront(header.ICMPv4MinimumSize)
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that is is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+ replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
@@ -110,39 +110,56 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
}
defer r.Release()
- // Use the remote link address from the incoming packet.
- r.ResolveWith(remoteLinkAddr)
-
- // Prepare a reply packet.
- icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize)
- copy(icmpHdr, h)
- icmpHdr.SetType(header.ICMPv4EchoReply)
- icmpHdr.SetChecksum(0)
- icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0)))
- dataVV := buffer.View(icmpHdr).ToVectorisedView()
- dataVV.Append(replyData)
+ // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // The fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket().
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: dataVV,
+ Data: replyVV,
})
- // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
- // information we will have to change this code to handle the ICMP header
- // no longer being in the data buffer.
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // Send out the reply packet.
+
+ // The checksum will be calculated so we don't need to do it here.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
- Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, replyPkt); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- sent := r.Stats().ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
@@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Assume any type we don't know about may be an error type.
return nil
}
- } else if transportHeader.IsEmpty() {
- return nil
}
// Now work out how much of the triggering packet we should return.
@@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv4MinimumProcessableDatagramSize {
mtu = header.IPv4MinimumProcessableDatagramSize
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
@@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
ReserveHeaderBytes: headerLen,
Data: payload,
})
+
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter := sent.DstUnreachable
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
- if err := r.WritePacket(
+ if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
+ TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 254d66147..c5ac7b8b5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -18,7 +18,9 @@ package ipv4
import (
"fmt"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -29,6 +31,15 @@ import (
)
const (
+ // As per RFC 791 section 3.2:
+ // The current recommendation for the initial timer setting is 15 seconds.
+ // This may be changed as experience with this protocol accumulates.
+ //
+ // Considering that it is an old recommendation, we use the same reassembly
+ // timeout that linux defines, which is 30 seconds:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
+ reassembleTimeout = 30 * time.Second
+
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -47,22 +58,113 @@ const (
fragmentblockSize = 8
)
+var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
+
+ // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ }
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
dispatcher: dispatcher,
protocol: p,
- stack: st,
+ }
+ e.mu.addressableEndpointState.Init(e)
+ return e
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ return err
+ }
+ // We have no need for the address endpoint.
+ ep.DecRef()
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ return err
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
+ }
+
+ // The address may have already been removed.
+ if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
}
@@ -74,31 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
-}
-
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -106,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is already present in pkt.NetworkHeader.
-// pkt.TransportHeader may be set. mtu includes the IP header and options. This
-// does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.NetworkHeader().View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
-
- // Keep the length reserved for link-layer, we need to create fragments with
- // the same reserved length.
- reservedForLink := pkt.AvailableHeaderBytes()
-
- // Destroy the packet, pull all payloads out for fragmentation.
- transHeader, data := pkt.TransportHeader().View(), pkt.Data
-
- // Where possible, the first fragment that is sent has the same
- // number of bytes reserved for header as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- transFitsFirst := len(transHeader) <= innerMTU
-
- for i := 0; i < n; i++ {
- reserve := reservedForLink + int(ip.HeaderLength())
- if i == 0 && transFitsFirst {
- // Reserve for transport header if it's going to be put in the first
- // fragment.
- reserve += len(transHeader)
- }
- fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: reserve,
- })
- fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-
- // Copy data for the fragment.
- avail := innerMTU
-
- if n := len(transHeader); n > 0 {
- if n > avail {
- n = avail
- }
- if i == 0 && transFitsFirst {
- copy(fragPkt.TransportHeader().Push(n), transHeader)
- } else {
- fragPkt.Data.AppendView(transHeader[:n:n])
- }
- transHeader = transHeader[n:]
- avail -= n
- }
-
- if avail > 0 {
- n := data.Size()
- if n > avail {
- n = avail
- }
- data.ReadToVV(&fragPkt.Data, n)
- avail -= n
- }
-
- copied := uint16(innerMTU - avail)
-
- // Set lengths in header and calculate checksum.
- h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
- copy(h, ip)
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + copied)
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += copied
-
- // Send out the fragment.
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+// writePacketFragments fragments pkt and writes the results on the link
+// endpoint. The IP header must already present in the original packet. The mtu
+// is the maximum size of the packets.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
r.Stats().IP.PacketsSent.Increment()
+ if !more {
+ break
+ }
}
+
return nil
}
@@ -219,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -228,8 +240,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -245,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -261,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -285,16 +298,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -308,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -317,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -377,23 +394,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 1122 section 3.2.1.3:
+ // When a host sends any datagram, the IP source address MUST
+ // be one of its own IP addresses (but not a broadcast or
+ // multicast address).
+ if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -449,6 +482,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
}
+
+ r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
// TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
@@ -458,7 +493,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleICMP(r, pkt)
return
}
- r.Stats().IP.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
@@ -468,24 +502,145 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
// Close cleans up resources associated with the endpoint.
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.disableLocked()
+ e.mu.addressableEndpointState.Cleanup()
+}
+
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ loopback := e.nic.IsLoopback()
+ addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ // IPv4 has a notion of a subnet broadcast address and considers the
+ // loopback interface bound to an address's whole subnet (on linux).
+ return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
+ })
+ if addressEndpoint != nil {
+ return addressEndpoint
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
+ if err != nil {
+ // AddAddress only returns an error if the address is already assigned,
+ // but we just checked above if the address exists so we expect no error.
+ panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
+ }
+ return addressEndpoint
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV4MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
- ids []uint32
- hashIV uint32
+ stack *stack.Stack
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ ids []uint32
+ hashIV uint32
+
fragmentation *fragmentation.Fragmentation
}
@@ -558,6 +713,20 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ if v {
+ atomic.StoreUint32(&p.forwarding, 1)
+ } else {
+ atomic.StoreUint32(&p.forwarding, 0)
+ }
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -567,19 +736,41 @@ func calculateMTU(mtu uint32) uint32 {
return mtu - header.IPv4MinimumSize
}
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ return mtu
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number and a 32-bit number to
+// generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -590,9 +781,33 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
}
}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
+ }
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 0b3ed9483..9916d783f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,19 +16,24 @@ package ipv4_test
import (
"bytes"
+ "context"
"encoding/hex"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ options []byte
+ }{
+ {
+ name: "valid",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPv4Options(test.options),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
+ }
+}
+
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if i == 0 {
- got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
- // sourcePacketInfo does not have NetworkHeader added, simulate one.
- want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
- // Check that it kept the transport header in packet.TransportHeader if
- // it fits in the first fragment.
- if want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- }
if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
@@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
func TestFragmentation(t *testing.T) {
+ const ttl = 42
+
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
@@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) {
payloadViewsSizes []int
expectedFrags int
}{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
}
for _, ft := range fragTests {
@@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) {
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("got err = %s, want = nil", err)
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
if got := len(ep.WrittenPackets); got != ft.expectedFrags {
@@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) {
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
@@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) {
// TestFragmentationErrors checks that errors are returned from write packet
// correctly.
func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ expectedError := tcpip.ErrAborted
fragTests := []struct {
description string
mtu uint32
transportHeaderLength int
- payloadViewsSizes []int
- err *tcpip.Error
+ payloadSize int
allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
+ {
+ description: "No frag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on first frag MTU smaller than header",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.err {
- t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
+ if err != expectedError {
+ t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ }
})
}
}
@@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
+
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
@@ -1040,7 +1345,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to find filter table")
}
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
@@ -1062,10 +1367,10 @@ func TestWriteStats(t *testing.T) {
}
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
@@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\x10\x00\x00\x02"
)
if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
}
return rt
}
@@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 8bd8f5c52..a30437f02 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -5,16 +5,20 @@ package(licenses = ["notice"])
go_library(
name = "ipv6",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "ndp.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -38,6 +42,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
index d199ded6a..09ba133b1 100644
--- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
+++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
@@ -14,7 +14,7 @@
// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
-package stack
+package ipv6
import "strconv"
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index dd3295b31..a454f6c34 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -39,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -207,14 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- s := r.Stack()
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now, drop this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
// attempting to perform address resolution on the target. In this case,
@@ -227,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if r.RemoteAddress == header.IPv6Any {
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
}
// Do not handle neighbor solicitations targeted to an address that is
@@ -240,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -275,7 +283,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ }
+
+ // As per RFC 4861 section 7.1.1:
+ // A node MUST silently discard any received Neighbor Solicitation
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP source address is the unspecified address, the IP
+ // destination address is a solicited-node multicast address.
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ received.Invalid.Increment()
+ return
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -314,18 +333,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
}
+ neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
})
- packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
+ na.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -342,7 +361,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.Invalid.Increment()
return
}
@@ -353,20 +372,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
targetAddr := na.TargetAddress()
- s := r.Stack()
-
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now short-circuit this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ //
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
return
}
@@ -396,7 +421,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -415,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := r.LocalAddress
@@ -424,16 +447,13 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
}
defer r.Release()
- // Use the link address from the source of the original packet.
- r.ResolveWith(remoteLinkAddr)
-
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data,
@@ -568,9 +588,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
}
- // Tell the NIC to handle the RA.
- stack := r.Stack()
- stack.HandleNDPRA(e.nicID, routerAddr, ra)
+ e.mu.Lock()
+ e.mu.ndp.handleRA(routerAddr, ra)
+ e.mu.Unlock()
case header.ICMPv6RedirectMsg:
// TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
@@ -600,18 +620,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
}
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
-
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-
- icmpV6FlagOffset = 4
- icmpV6OptOffset = 24
- icmpV6LengthOffset = 25
-)
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -621,31 +629,37 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- snaddr := header.SolicitedNodeAddr(addr)
-
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
// route here. Note, we would need the nicID to do this properly so the right
// NIC (associated to linkEP) is used to send the NDP NS message.
- r := &stack.Route{
+ r := stack.Route{
LocalAddress: localAddr,
- RemoteAddress: snaddr,
+ RemoteAddress: addr,
RemoteLinkAddress: remoteLinkAddr,
}
+
+ // If a remote address is not already known, then send a multicast
+ // solicitation since multicast addresses have a static mapping to link
+ // addresses.
if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ r.RemoteAddress = header.SolicitedNodeAddr(addr)
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
}
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ }
+ neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
})
- icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- icmpHdr.SetType(header.ICMPv6NeighborSolicit)
- copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
- icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
- icmpHdr[icmpV6LengthOffset] = 1
- copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
+ packet.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns.SetTargetAddress(addr)
+ ns.Options().Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
@@ -658,7 +672,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -676,6 +690,36 @@ type icmpReason interface {
isICMPReason()
}
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
type icmpReasonPortUnreachable struct{}
@@ -684,18 +728,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- stats := r.Stats().ICMP
- sent := stats.V6PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
- // TODO(b/164522993) There are exceptions to this rule.
+ // There are exceptions to this rule.
// See: point e.3) RFC 4443 section-2.4
//
// (e) An ICMPv6 error message MUST NOT be originated as a result of
@@ -713,7 +750,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any {
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return nil
}
@@ -743,11 +805,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// packet that caused the error) as possible without making
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv6MinimumMTU {
mtu = header.IPv6MinimumMTU
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
available := int(mtu) - headerLen
if available < header.IPv6MinimumSize {
return nil
@@ -766,12 +828,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
- counter := sent.DstUnreachable
- err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
- if err != nil {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
sent.Dropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index dd58022d6..3affcc4e4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,21 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -39,6 +43,9 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
@@ -50,6 +57,10 @@ type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
// Indicate that resolution for link layer addresses is required to send
// packets over this link. This is needed so the NIC knows to allocate a
@@ -103,6 +114,28 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin
func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -150,9 +183,13 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -288,9 +325,13 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -1187,24 +1228,30 @@ func TestLinkAddressRequest(t *testing.T) {
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectedLinkAddr tcpip.LinkAddress
+ expectedAddr tcpip.Address
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectLinkAddr: linkAddr1,
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectedLinkAddr: linkAddr1,
+ expectedAddr: lladdr0,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectLinkAddr: mcaddr,
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectedLinkAddr: mcaddr,
+ expectedAddr: snaddr,
},
}
for _, test := range tests {
- p := NewProtocol(nil)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
linkRes, ok := p.(stack.LinkAddressResolver)
if !ok {
t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
@@ -1219,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
+ if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ }
+ if pkt.Route.RemoteAddress != test.expectedAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ }
+ if pkt.Route.LocalAddress != lladdr1 {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ }
+}
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
}
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e436c6a9e..2bd8f4ece 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,18 +16,33 @@
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
+ "sort"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // As per RFC 8200 section 4.5:
+ // If insufficient fragments are received to complete reassembly of a packet
+ // within 60 seconds of the reception of the first-arriving fragment of that
+ // packet, reassembly of that packet must be abandoned.
+ //
+ // Linux also uses 60 seconds for reassembly timeout:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
+ reassembleTimeout = 60 * time.Second
+
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -38,16 +53,306 @@ const (
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets for fragment identifiers
+ buckets = 2048
)
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+var _ stack.NDPEndpoint = (*endpoint)(nil)
+var _ NDPEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
+
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ ndp ndpState
+ }
+}
+
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it is passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
+// InvalidateDefaultRouter implements stack.NDPEndpoint.
+func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.invalidateDefaultRouter(rtr)
+}
+
+// SetNDPConfigurations implements NDPEndpoint.
+func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
+ c.validate()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.configs = c
+}
+
+// hasTentativeAddr returns true if addr is tentative on e.
+func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
+ e.mu.RLock()
+ addressEndpoint := e.getAddressRLocked(addr)
+ e.mu.RUnlock()
+ return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform e that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected removes the tentative address if it exists. If the
+// address was generated via SLAAC, an attempt is made to generate a new
+// address.
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return tcpip.ErrBadAddress
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := addressEndpoint.AddressWithPrefix().Subnet()
+
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
+}
+
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.Enabled() {
+ return
+ }
+
+ if forwarding {
+ // When transitioning into an IPv6 router, host-only state (NDP discovered
+ // routers, discovered on-link prefixes, and auto-generated addresses) is
+ // cleaned up/invalidated and NDP router solicitations are stopped.
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(true /* hostOnly */)
+ } else {
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started.
+ e.mu.ndp.startSolicitingRouters()
+ }
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the endpoint's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the endpoint
+ // was last enabled, other devices may have acquired the same addresses.
+ var err *tcpip.Error
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if !header.IsV6UnicastAddress(addr) {
+ return true
+ }
+
+ switch addressEndpoint.GetKind() {
+ case stack.Permanent:
+ addressEndpoint.SetKind(stack.PermanentTentative)
+ fallthrough
+ case stack.PermanentTentative:
+ err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint)
+ return err == nil
+ default:
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !e.protocol.Forwarding() {
+ e.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.stopDADForPermanentAddressesLocked()
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
+ }
+}
+
+// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) stopDADForPermanentAddressesLocked() {
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ return true
+ })
}
// DefaultTTL is the default hop limit for this endpoint.
@@ -58,31 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
@@ -96,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+ return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet. The mtu is the maximum size of the packets. The transport
+// header protocol number is required to avoid parsing the IPv6 extension
+// headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ if fragMTU < pkt.TransportHeader().View().Size() {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ break
+ }
+ }
+
+ return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -105,8 +429,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -122,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -143,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
+
r.Stats().IP.PacketsSent.Increment()
return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -161,18 +500,38 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
+ if e.packetMustBeFragmented(pb, gso) {
+ current := pb
+ _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(current, fragPkt)
+ current = current.Next()
+ return nil
+ })
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // The fragmented packet can be released. The rest of the packets can be
+ // processed.
+ pkts.Remove(pb)
+ pb = current
+ }
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -186,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -195,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -219,12 +579,24 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 4291 section 2.7:
+ // Multicast addresses must not be used as source addresses in IPv6
+ // packets or appear in any Routing header.
+ if header.IsV6MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -236,15 +608,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
hasFragmentHeader := false
// iptables filtering. All packets that reach here are intended for
- // this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ // this machine and need not be forwarded.
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
return
}
- for firstHeader := true; ; firstHeader = false {
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -258,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -284,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -301,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -445,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -470,13 +873,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
+ r.Stats().IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -485,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -504,19 +929,343 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
// Close cleans up resources associated with the endpoint.
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.disableLocked()
+ e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
+ e.stopDADForPermanentAddressesLocked()
+ e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
+
+ e.protocol.forgetEndpoint(e)
+}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ // TODO(b/169350103): add checks here after making sure we no longer receive
+ // an empty address.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+}
+
+// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
+// with locking requirements.
+//
+// addAndAcquirePermanentAddressLocked also joins the passed address's
+// solicited-node multicast group and start duplicate address detection.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err != nil {
+ return nil, err
+ }
+
+ if !header.IsV6UnicastAddress(addr.Address) {
+ return addressEndpoint, nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ return nil, err
+ }
+
+ addressEndpoint.SetKind(stack.PermanentTentative)
+
+ if e.Enabled() {
+ if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil {
+ return nil, err
+ }
+ }
+
+ return addressEndpoint, nil
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return e.removePermanentEndpointLocked(addressEndpoint, true)
+}
+
+// removePermanentEndpointLocked is like removePermanentAddressLocked except
+// it works with a stack.AddressEndpoint.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := addressEndpoint.AddressWithPrefix()
+ unicast := header.IsV6UnicastAddress(addr.Address)
+ if unicast {
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch addressEndpoint.ConfigType() {
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case stack.AddressConfigSlaacTemp:
+ e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
+ return err
+ }
+
+ if !unicast {
+ return nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ return nil
+}
+
+// hasPermanentAddressLocked returns true if the endpoint has a permanent
+// address equal to the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return false
+ }
+ return addressEndpoint.GetKind().IsPermanent()
+}
+
+// getAddressRLocked returns the endpoint for the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
+}
+
+// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with
+// locking requirements.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ // addrCandidate is a candidate for Source Address Selection, as per
+ // RFC 6724 section 5.
+ type addrCandidate struct {
+ addressEndpoint stack.AddressEndpoint
+ scope header.IPv6AddressScope
+ }
+
+ if len(remoteAddr) == 0 {
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ var cs []addrCandidate
+ e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !addressEndpoint.IsAssigned(allowExpired) {
+ return
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, addrCandidate{
+ addressEndpoint: addressEndpoint,
+ scope: scope,
+ })
+ })
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return true
+ }
+ if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if c.addressEndpoint.IncRef() {
+ return c.addressEndpoint
+ }
+ }
+
+ return nil
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV6MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
+
type protocol struct {
+ stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ eps map[*endpoint]struct{}
+ }
+
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
- defaultTTL uint32
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
+ defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
fragmentation *fragmentation.Fragmentation
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
+ ndpConfigs NDPConfigurations
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // autoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -541,16 +1290,35 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
+ e.mu.addressableEndpointState.Init(e)
+ e.mu.ndp = ndpState{
+ ep: e,
+ configs: p.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ e.mu.ndp.initializeTempAddrState()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.mu.eps[e] = struct{}{}
+ return e
+}
+
+func (p *protocol) forgetEndpoint(e *endpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, e)
}
// SetOption implements NetworkProtocol.SetOption.
@@ -601,6 +1369,35 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&p.forwarding, 1) == 0
+ }
+ return atomic.SwapUint32(&p.forwarding, 0) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for ep := range p.mu.eps {
+ ep.transitionForwarding(v)
+ }
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -611,10 +1408,144 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-// NewProtocol returns an IPv6 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
- return &protocol{
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+// Options holds options to configure a new protocol.
+type Options struct {
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address is
+ // assigned right away, or at all. If Duplicate Address Detection is enabled,
+ // an address is only assigned if it successfully resolves. If it fails, no
+ // further attempts are made to auto-generate an IPv6 link-local adddress.
+ //
+ // The generated link-local address follows RFC 4291 Appendix A guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// NewProtocolWithOptions returns an IPv6 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+ opts.NDPConfigs.validate()
+
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
+
+ ndpDisp: opts.NDPDisp,
+ ndpConfigs: opts.NDPConfigs,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ }
+ p.mu.eps = make(map[*endpoint]struct{})
+ p.SetDefaultTTL(DefaultTTL)
+ return p
}
}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
+}
+
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // MTU because they should only be transmitted once.
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ mtu -= header.IPv6FragmentHeaderSize
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address and 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // The FNV-1a was chosen because it is a fast hashing algorithm, and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 8ae146c5e..bee18d1a8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,17 +15,21 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
"math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -53,8 +57,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
t.Helper()
// Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
@@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be similar across each fragment, besides the
+ // payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ result := reassembledPayload.ToView()
+ if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
@@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
protocolFactory stack.TransportProtocolFactory
@@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -336,9 +419,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -348,12 +432,58 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -364,39 +494,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
@@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -683,7 +972,6 @@ type fragmentData struct {
func TestReceiveIPv6Fragments(t *testing.T) {
const (
- nicID = 1
udpPayload1Length = 256
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
@@ -1748,7 +2036,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to find filter table")
}
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
@@ -1770,10 +2058,10 @@ func TestWriteStats(t *testing.T) {
}
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
@@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
-
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
}
return rt
}
@@ -1895,3 +2183,320 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if !hasEP {
+ t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
+ }
+ }
+}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+type fragmentationTestCase struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ extraHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ expectedFrags int
+}
+
+var fragmentationTests = []fragmentationTestCase{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso nil",
+ mtu: 1280,
+ gso: nil,
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header and prependable bytes",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 20,
+ extraHdrLen: header.IPv6MinimumSize + 66,
+ payloadSize: 1500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 296, more: false},
+ },
+ },
+}
+
+func TestFragmentation(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if len(ep.WrittenPackets) > 0 {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on packet with MTU smaller than transport header",
+ mtu: 1280,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 97ca00d16..40da011f8 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package ipv6
import (
"fmt"
@@ -23,9 +23,27 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
// defaultDupAddrDetectTransmits is the default number of NDP Neighbor
// Solicitation messages to send when doing Duplicate Address Detection
// for a tentative address.
@@ -34,7 +52,7 @@ const (
defaultDupAddrDetectTransmits = 1
// defaultMaxRtrSolicitations is the default number of Router
- // Solicitation messages to send when a NIC becomes enabled.
+ // Solicitation messages to send when an IPv6 endpoint becomes enabled.
//
// Default = 3 (from RFC 4861 section 10).
defaultMaxRtrSolicitations = 3
@@ -131,7 +149,7 @@ const (
minRegenAdvanceDuration = time.Duration(0)
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
- // SLAAC address regenerations in response to a NIC-local conflict.
+ // SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
)
@@ -163,7 +181,7 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be preferred for at
+ // This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -173,11 +191,17 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be valid for at least
+ // This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
)
+// NDPEndpoint is an endpoint that supports NDP.
+type NDPEndpoint interface {
+ // SetNDPConfigurations sets the NDP configurations.
+ SetNDPConfigurations(NDPConfigurations)
+}
+
// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
// NDP Router Advertisement informed the Stack about.
type DHCPv6ConfigurationFromNDPRA int
@@ -192,7 +216,7 @@ const (
// DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
//
// DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
- // will return all available configuration information.
+ // returns all available configuration information when serving addresses.
DHCPv6ManagedAddress
// DHCPv6OtherConfigurations indicates that other configuration information is
@@ -207,19 +231,18 @@ const (
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
- // OnDuplicateAddressDetectionStatus will be called when the DAD process
- // for an address (addr) on a NIC (with ID nicID) completes. resolved
- // will be set to true if DAD completed successfully (no duplicate addr
- // detected); false otherwise (addr was detected to be a duplicate on
- // the link the NIC is a part of, or it was stopped for some other
- // reason, such as the address being removed). If an error occured
- // during DAD, err will be set and resolved must be ignored.
+ // OnDuplicateAddressDetectionStatus is called when the DAD process for an
+ // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
+ // if DAD completed successfully (no duplicate addr detected); false otherwise
+ // (addr was detected to be a duplicate on the link the NIC is a part of, or
+ // it was stopped for some other reason, such as the address being removed).
+ // If an error occured during DAD, err is set and resolved must be ignored.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
- // OnDefaultRouterDiscovered will be called when a new default router is
+ // OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
// router should be remembered.
//
@@ -227,56 +250,55 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
- // OnDefaultRouterInvalidated will be called when a discovered default
- // router that was remembered is invalidated.
+ // OnDefaultRouterInvalidated is called when a discovered default router that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
- // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
- // discovered. Implementations must return true if the newly discovered
- // on-link prefix should be remembered.
+ // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
+ // Implementations must return true if the newly discovered on-link prefix
+ // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
- // OnOnLinkPrefixInvalidated will be called when a discovered on-link
- // prefix that was remembered is invalidated.
+ // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
- // OnAutoGenAddress will be called when a new prefix with its
- // autonomous address-configuration flag set has been received and SLAAC
- // has been performed. Implementations may prevent the stack from
- // assigning the address to the NIC by returning false.
+ // OnAutoGenAddress is called when a new prefix with its autonomous address-
+ // configuration flag set is received and SLAAC was performed. Implementations
+ // may prevent the stack from assigning the address to the NIC by returning
+ // false.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
- // OnAutoGenAddressDeprecated will be called when an auto-generated
- // address (as part of SLAAC) has been deprecated, but is still
- // considered valid. Note, if an address is invalidated at the same
- // time it is deprecated, the deprecation event MAY be omitted.
+ // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
+ // is deprecated, but is still considered valid. Note, if an address is
+ // invalidated at the same ime it is deprecated, the deprecation event may not
+ // be received.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnAutoGenAddressInvalidated will be called when an auto-generated
- // address (as part of SLAAC) has been invalidated.
+ // OnAutoGenAddressInvalidated is called when an auto-generated address
+ // (SLAAC) is invalidated.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnRecursiveDNSServerOption will be called when an NDP option with
- // recursive DNS servers has been received. Note, addrs may contain
- // link-local addresses.
+ // OnRecursiveDNSServerOption is called when the stack learns of DNS servers
+ // through NDP. Note, the addresses may contain link-local addresses.
//
// It is up to the caller to use the DNS Servers only for their valid
// lifetime. OnRecursiveDNSServerOption may be called for new or
@@ -288,8 +310,8 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
- // OnDNSSearchListOption will be called when an NDP option with a DNS
- // search list has been received.
+ // OnDNSSearchListOption is called when the stack learns of DNS search lists
+ // through NDP.
//
// It is up to the caller to use the domain names in the search list
// for only their valid lifetime. OnDNSSearchListOption may be called
@@ -298,8 +320,8 @@ type NDPDispatcher interface {
// be increased, decreased or completely invalidated when lifetime = 0.
OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
- // OnDHCPv6Configuration will be called with an updated configuration that is
- // available via DHCPv6 for a specified NIC.
+ // OnDHCPv6Configuration is called with an updated configuration that is
+ // available via DHCPv6 for the passed NIC.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
@@ -320,7 +342,7 @@ type NDPConfigurations struct {
// Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
- // The number of Router Solicitation messages to send when the NIC
+ // The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
MaxRtrSolicitations uint8
@@ -335,24 +357,22 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements will be
- // processed.
+ // HandleRAs determines whether or not Router Advertisements are processed.
HandleRAs bool
- // DiscoverDefaultRouters determines whether or not default routers will
- // be discovered from Router Advertisements. This configuration is
- // ignored if HandleRAs is false.
+ // DiscoverDefaultRouters determines whether or not default routers are
+ // discovered from Router Advertisements, as per RFC 4861 section 6. This
+ // configuration is ignored if HandleRAs is false.
DiscoverDefaultRouters bool
- // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
- // will be discovered from Router Advertisements' Prefix Information
- // option. This configuration is ignored if HandleRAs is false.
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
+ // discovered from Router Advertisements' Prefix Information option, as per
+ // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
DiscoverOnLinkPrefixes bool
- // AutoGenGlobalAddresses determines whether or not global IPv6
- // addresses will be generated for a NIC in response to receiving a new
- // Prefix Information option with its Autonomous Address
- // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
+ // SLAAC to auto-generate global SLAAC addresses in response to Prefix
+ // Information options, as per RFC 4862.
//
// Note, if an address was already generated for some unique prefix, as
// part of SLAAC, this option does not affect whether or not the
@@ -366,12 +386,12 @@ type NDPConfigurations struct {
//
// If the method used to generate the address does not support creating
// alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
- // MAC address), then no attempt will be made to resolve the conflict.
+ // MAC address), then no attempt is made to resolve the conflict.
AutoGenAddressConflictRetries uint8
// AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
- // addresses will be generated for a NIC as part of SLAAC privacy extensions,
- // RFC 4941.
+ // addresses are generated for an IPv6 endpoint as part of SLAAC privacy
+ // extensions, as per RFC 4941.
//
// Ignored if AutoGenGlobalAddresses is false.
AutoGenTempGlobalAddresses bool
@@ -410,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
}
// validate modifies an NDPConfigurations with valid values. If invalid values
-// are present in c, the corresponding default values will be used instead.
+// are present in c, the corresponding default values are used instead.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
@@ -439,8 +459,8 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
- // The NIC this ndpState is for.
- nic *NIC
+ // The IPv6 endpoint this ndpState is for.
+ ep *endpoint
// configs is the per-interface NDP configurations.
configs NDPConfigurations
@@ -458,8 +478,8 @@ type ndpState struct {
// Used to let the Router Solicitation timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this ndpState is associated with. MUST be set when the timer is
- // set.
+ // the IPv6 endpoint this ndpState is associated with. MUST be set when the
+ // timer is set.
done *bool
}
@@ -492,7 +512,7 @@ type dadState struct {
// Used to let the DAD timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this dadState is associated with.
+ // the IPv6 endpoint this dadState is associated with.
done *bool
}
@@ -537,7 +557,7 @@ type tempSLAACAddrState struct {
// The address's endpoint.
//
// Must not be nil.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
// Has a new temporary SLAAC address already been regenerated?
regenerated bool
@@ -567,10 +587,10 @@ type slaacPrefixState struct {
//
// May only be nil when the address is being (re-)generated. Otherwise,
// must not be nil as all SLAAC prefixes must have a stable address.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
- // The number of times an address has been generated locally where the NIC
- // already had the generated address.
+ // The number of times an address has been generated locally where the IPv6
+ // endpoint already had the generated address.
localGenerationFailures uint8
}
@@ -578,11 +598,12 @@ type slaacPrefixState struct {
tempAddrs map[tcpip.Address]tempSLAACAddrState
// The next two fields are used by both stable and temporary addresses
- // generated for a SLAAC prefix. This is safe as only 1 address will be
- // in the generation and DAD process at any time. That is, no two addresses
- // will be generated at the same time for a given SLAAC prefix.
+ // generated for a SLAAC prefix. This is safe as only 1 address is in the
+ // generation and DAD process at any time. That is, no two addresses are
+ // generated at the same time for a given SLAAC prefix.
- // The number of times an address has been generated and added to the NIC.
+ // The number of times an address has been generated and added to the IPv6
+ // endpoint.
//
// Addresses may be regenerated in reseponse to a DAD conflicts.
generationAttempts uint8
@@ -597,16 +618,16 @@ type slaacPrefixState struct {
// This function must only be called by IPv6 addresses that are currently
// tentative.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
return tcpip.ErrAddressFamilyNotSupported
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -617,18 +638,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// existed, we would get an error since we attempted to add a duplicate
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
- // See NIC.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ // See endpoint.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
if remaining == 0 {
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
@@ -637,25 +658,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
var done bool
var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the NIC's lock. This is effectively the same
- // as starting a goroutine but we use a timer that fires immediately so we can
- // reset it for the next DAD iteration.
- timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
if done {
// If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the NIC lock and stopped DAD
- // before this function obtained the NIC lock. Simply return here and do
- // nothing further.
+ // another goroutine already obtained the IPv6 endpoint lock and stopped
+ // DAD before this function obtained the NIC lock. Simply return here and
+ // do nothing further.
return
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -663,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
var err *tcpip.Error
if !dadDone {
// Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
// Do not hold the lock when sending packets which may be a long running
// task or may block link address resolution. We know this is safe
// because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the NIC. Note, DAD would be
- // stopped if the NIC was disabled or removed, or if the address was
- // removed.
- ndp.nic.mu.Unlock()
- err = ndp.sendDADPacket(addr, ref)
- ndp.nic.mu.Lock()
+ // has been stopped before doing any work with the IPv6 endpoint. Note,
+ // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
+ // the address was removed.
+ ndp.ep.mu.Unlock()
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ ndp.ep.mu.Lock()
+ addressEndpoint.DecRef()
}
if done {
// If we reach this point, it means that DAD was stopped after we released
- // the NIC's read lock and before we obtained the write lock.
+ // the IPv6 endpoint's read lock and before we obtained the write lock.
return
}
if dadDone {
// DAD has resolved.
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
} else if err == nil {
// DAD is not done and we had no errors when sending the last NDP NS,
// schedule the next DAD timer.
remaining--
- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ timer.Reset(ndp.configs.RetransmitTimer)
return
}
@@ -698,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// integrator know DAD has completed.
delete(ndp.dad, addr)
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
// temporary SLAAC address.
- if dadDone && ref.configType == slaac {
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
}
})
@@ -722,28 +744,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
// addr.
//
-// addr must be a tentative IPv6 address on ndp's NIC.
+// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
//
-// The NIC ndp belongs to MUST NOT be locked.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is not
+ // locked, the NIC could have been disabled or removed by another goroutine.
if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -752,17 +777,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
ns.SetTargetAddress(addr)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- pkt := NewPacketBuffer(PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
}, pkt,
); err != nil {
sent.Dropped.Increment()
@@ -778,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
// such a state forever, unless some other external event resolves the DAD
// process (receiving an NA from the true owner of addr, or an NS for addr
// (implying another node is attempting to use addr)). It is up to the caller
-// of this function to handle such a scenario. Normally, addr will be removed
-// from n right after this function returns or the address successfully
-// resolved.
+// of this function to handle such a scenario.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
dad, ok := ndp.dad[addr]
if !ok {
@@ -801,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- // Is the NIC configured to handle RAs at all?
+ // Is the IPv6 endpoint configured to handle RAs at all?
//
// Currently, the stack does not determine router interface status on a
- // per-interface basis; it is a stack-wide configuration, so we check
- // stack's forwarding flag to determine if the NIC is a routing
- // interface.
- if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) {
+ // per-interface basis; it is a protocol-wide configuration, so we check the
+ // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
+ // packets.
+ if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
return
}
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -839,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
- // Is the NIC configured to discover default routers?
+ // Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
rl := ra.RouterLifetime()
@@ -881,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -928,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// invalidateDefaultRouter invalidates a discovered default router.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
rtr, ok := ndp.defaultRouters[ip]
@@ -942,32 +964,32 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
// rememberDefaultRouter remembers a newly discovered default router with IPv6
// link-local address ip with lifetime rl.
//
-// The router identified by ip MUST NOT already be known by the NIC.
+// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
@@ -982,22 +1004,22 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The prefix identified by prefix MUST NOT already be known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
@@ -1011,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
s, ok := ndp.onLinkPrefixes[prefix]
@@ -1025,8 +1047,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1036,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the on-link flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
prefix := pi.Subnet()
prefixState, ok := ndp.onLinkPrefixes[prefix]
@@ -1089,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -1125,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -1142,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
}
- ndp.deprecateSLAACAddress(state.stableAddr.ref)
+ ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1189,7 +1211,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
// If the address is assigned (DAD resolved), generate a temporary address.
- if state.stableAddr.ref.getKind() == permanent {
+ if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
@@ -1198,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.slaacPrefixes[prefix] = state
}
-// addSLAACAddr adds a SLAAC address to the NIC.
+// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr,
- }
-
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
- panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
- return ref
+ return addressEndpoint
}
// generateSLAACAddr generates a SLAAC address for prefix.
@@ -1232,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo
//
// Panics if the prefix is not a SLAAC prefix or it already has an address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
- if r := state.stableAddr.ref; r != nil {
- panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix()))
}
// If we have already reached the maximum address generation attempts for the
@@ -1255,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1272,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
+ linkAddr := ndp.ep.nic.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
@@ -1291,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
PrefixLen: validPrefixLenForAutoGen,
}
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
state.stableAddr.localGenerationFailures++
}
- if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
- state.stableAddr.ref = ref
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
}
@@ -1309,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
//
-// If generating a new address for the prefix fails, the prefix will be
-// invalidated.
+// If generating a new address for the prefix fails, the prefix is invalidated.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1332,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
// generateTempSLAACAddr generates a new temporary SLAAC address.
//
-// If resetGenAttempts is true, the prefix's generation counter will be reset.
+// If resetGenAttempts is true, the prefix's generation counter is reset.
//
// Returns true if a new address was generated.
func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
@@ -1353,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- stableAddr := prefixState.stableAddr.ref.address()
+ stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
now := time.Now()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
@@ -1392,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- // Attempt to generate a new address that is not already assigned to the NIC.
+ // Attempt to generate a new address that is not already assigned to the IPv6
+ // endpoint.
var generatedAddr tcpip.AddressWithPrefix
for i := 0; ; i++ {
// If we were unable to generate an address after the maximum SLAAC address
@@ -1402,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
}
@@ -1410,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
// address with a zero preferred lifetime. The checks above ensure this
// so we know the address is not deprecated.
- ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
- if ref == nil {
+ addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */)
+ if addressEndpoint == nil {
return false
}
state := tempSLAACAddrState{
- deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1427,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
}
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1442,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1465,8 +1482,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
ndp.slaacPrefixes[prefix] = prefixState
}),
- createdAt: now,
- ref: ref,
+ createdAt: now,
+ addressEndpoint: addressEndpoint,
}
state.deprecationJob.Schedule(pl)
@@ -1481,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1496,14 +1513,14 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
// If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
if deprecated {
- ndp.deprecateSLAACAddress(prefixState.stableAddr.ref)
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint)
} else {
- prefixState.stableAddr.ref.deprecated = false
+ prefixState.stableAddr.addressEndpoint.SetDeprecated(false)
}
// If prefix was preferred for some finite lifetime before, cancel the
@@ -1565,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// If DAD is not yet complete on the stable address, there is no need to do
// work with temporary addresses.
- if prefixState.stableAddr.ref.getKind() != permanent {
+ if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent {
return
}
@@ -1608,9 +1625,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
newPreferredLifetime := preferredUntil.Sub(now)
tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
- tempAddrState.ref.deprecated = false
+ tempAddrState.addressEndpoint.SetDeprecated(false)
tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
@@ -1635,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// due to an update in preferred lifetime.
//
// If each temporay address has already been regenerated, no new temporary
- // address will be generated. To ensure continuation of temporary SLAAC
- // addresses, we manually try to regenerate an address here.
+ // address is generated. To ensure continuation of temporary SLAAC addresses,
+ // we manually try to regenerate an address here.
if len(regenForAddr) != 0 || allAddressesRegenerated {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
@@ -1647,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
}
-// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
-// dispatcher that ref has been deprecated.
+// deprecateSLAACAddress marks the address as deprecated and notifies the NDP
+// dispatcher that address has been deprecated.
//
-// deprecateSLAACAddress does nothing if ref is already deprecated.
+// deprecateSLAACAddress does nothing if the address is already deprecated.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
- if ref.deprecated {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) {
+ if addressEndpoint.Deprecated() {
return
}
- ref.deprecated = true
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
+ addressEndpoint.SetDeprecated(true)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
- if r := state.stableAddr.ref; r != nil {
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
// Since we are already invalidating the prefix, do not invalidate the
// prefix when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
-
- ndp.cleanupSLAACPrefixResources(prefix, state)
}
// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
// resources.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
state, ok := ndp.slaacPrefixes[prefix]
- if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() {
+ if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address {
return
}
if !invalidatePrefix {
// If the prefix is not being invalidated, disassociate the address from the
// prefix and do nothing further.
- state.stableAddr.ref = nil
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
ndp.slaacPrefixes[prefix] = state
return
}
@@ -1709,14 +1727,17 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
//
// Panics if the SLAAC prefix is not known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
// Invalidate all temporary addresses.
for tempAddr, tempAddrState := range state.tempAddrs {
ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
}
- state.stableAddr.ref = nil
+ if state.stableAddr.addressEndpoint != nil {
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ }
state.deprecationJob.Cancel()
state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
@@ -1724,12 +1745,12 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
// Since we are already invalidating the address, do not invalidate the
// address when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
@@ -1738,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
// SLAAC address's resources from ndp.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1765,35 +1786,29 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
// jobs and entry.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.addressEndpoint.DecRef()
+ tempAddrState.addressEndpoint = nil
tempAddrState.deprecationJob.Cancel()
tempAddrState.invalidationJob.Cancel()
tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
-// cleanupState cleans up ndp's state.
-//
-// If hostOnly is true, then only host-specific state will be cleaned up.
+// removeSLAACAddresses removes all SLAAC addresses.
//
-// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
-// transitioning from a host to a router. This function will invalidate all
-// discovered on-link prefixes, discovered routers, and auto-generated
-// addresses.
-//
-// If hostOnly is true, then the link-local auto-generated address will not be
-// invalidated as routers are also expected to generate a link-local address.
+// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalPrefixes := 0
+ var linkLocalPrefixes int
for prefix, state := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && prefix == linkLocalSubnet {
+ if keepLinkLocal && prefix == linkLocalSubnet {
linkLocalPrefixes++
continue
}
@@ -1804,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
}
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state is cleaned up.
+//
+// This function invalidates all discovered on-link prefixes, discovered
+// routers, and auto-generated addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address aren't
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1827,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
@@ -1848,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
- ndp.nic.mu.Lock()
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the NIC lock and stopped solicitations.
- // Simply return here and do nothing further.
- ndp.nic.mu.Unlock()
+ // goroutine already obtained the IPv6 endpoint lock and stopped
+ // solicitations. Simply return here and do nothing further.
+ ndp.ep.mu.Unlock()
return
}
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
- if ref == nil {
- ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ if addressEndpoint == nil {
+ // Incase this ends up creating a new temporary address, we need to hold
+ // onto the endpoint until a route is obtained. If we decrement the
+ // reference count before obtaing a route, the address's resources would
+ // be released and attempting to obtain a route after would fail. Once a
+ // route is obtainted, it is safe to decrement the reference count since
+ // obtaining a route increments the address's reference count.
+ addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
- localAddr := ref.address()
- r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ addressEndpoint.DecRef()
+ if err != nil {
+ return
+ }
defer r.Release()
// Route should resolve immediately since
@@ -1876,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() {
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is
+ // not locked, the IPv6 endpoint could have been disabled or removed by
+ // another goroutine.
if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1907,21 +1948,20 @@ func (ndp *ndpState) startSolicitingRouters() {
rs.Options().Serialize(optsSerializer)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- pkt := NewPacketBuffer(PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
}, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -1929,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.nic.mu.Lock()
+ ndp.ep.mu.Lock()
if done || remaining == 0 {
ndp.rtrSolicit.timer = nil
ndp.rtrSolicit.done = nil
} else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
- // we still reached this point, then we know the NIC
+ // we still reached this point, then we know the IPv6 endpoint
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
})
}
@@ -1949,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
if ndp.rtrSolicit.timer == nil {
// Nothing to do.
@@ -1965,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index c93d1194f..9033a9ed5 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"strings"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -65,10 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ t.Cleanup(ep.Close)
+
return s, ep
}
+var _ NDPDispatcher = (*testNDPDispatcher)(nil)
+
+// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
+type testNDPDispatcher struct {
+ addr tcpip.Address
+}
+
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+ t.addr = addr
+ return true
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+ t.addr = addr
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
+}
+
+func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
+ var ndpDisp testNDPDispatcher
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+
+ ipv6EP := ep.(*endpoint)
+ ipv6EP.mu.Lock()
+ ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.Unlock()
+
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+
+ ndpDisp.addr = ""
+ ndpEP := ep.(stack.NDPEndpoint)
+ ndpEP.InvalidateDefaultRouter(lladdr1)
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -325,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naDst tcpip.Address
}{
{
- name: "Unspecified source to multicast destination",
+ name: "Unspecified source to solicited-node multicast destination",
nsOpts: nil,
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
@@ -352,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
+ nsInvalid: true,
},
{
name: "Unspecified source with source ll option to unicast destination",
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index c9e57dc0d..d0ffc299a 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -8,6 +8,7 @@ go_library(
"testutil.go",
],
visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 7f1d79115..eba97334e 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -54,9 +54,8 @@ go_template_instance(
go_library(
name = "stack",
srcs = [
+ "addressable_endpoint_state.go",
"conntrack.go",
- "dhcpv6configurationfromndpra_string.go",
- "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -65,7 +64,6 @@ go_library(
"iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
- "ndp.go",
"neighbor_cache.go",
"neighbor_entry.go",
"neighbor_entry_list.go",
@@ -74,6 +72,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -106,6 +105,7 @@ go_test(
name = "stack_x_test",
size = "medium",
srcs = [
+ "addressable_endpoint_state_test.go",
"ndp_test.go",
"nud_test.go",
"stack_test.go",
@@ -116,13 +116,13 @@ go_test(
deps = [
":stack",
"//pkg/rand",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -138,7 +138,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarder_test.go",
+ "forwarding_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
new file mode 100644
index 000000000..4d3acab96
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -0,0 +1,753 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
+var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
+
+// AddressableEndpointState is an implementation of an AddressableEndpoint.
+type AddressableEndpointState struct {
+ networkEndpoint NetworkEndpoint
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ endpoints map[tcpip.Address]*addressState
+ primary []*addressState
+
+ // groups holds the mapping between group addresses and the number of times
+ // they have been joined.
+ groups map[tcpip.Address]uint32
+ }
+}
+
+// Init initializes the AddressableEndpointState with networkEndpoint.
+//
+// Must be called before calling any other function on m.
+func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
+ a.networkEndpoint = networkEndpoint
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.endpoints = make(map[tcpip.Address]*addressState)
+ a.mu.groups = make(map[tcpip.Address]uint32)
+}
+
+// ReadOnlyAddressableEndpointState provides read-only access to an
+// AddressableEndpointState.
+type ReadOnlyAddressableEndpointState struct {
+ inner *AddressableEndpointState
+}
+
+// AddrOrMatching returns an endpoint for the passed address that is consisdered
+// bound to the wrapped AddressableEndpointState.
+//
+// If addr is an exact match with an existing address, that address is returned.
+// Otherwise, f is called with each address and the address that f returns true
+// for is returned.
+//
+// Returns nil of no address matches.
+func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ if ep, ok := m.inner.mu.endpoints[addr]; ok {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ for _, ep := range m.inner.mu.endpoints {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ return nil
+}
+
+// Lookup returns the AddressEndpoint for the passed address.
+//
+// Returns nil if the passed address is not associated with the
+// AddressableEndpointState.
+func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ ep, ok := m.inner.mu.endpoints[addr]
+ if !ok {
+ return nil
+ }
+ return ep
+}
+
+// ForEach calls f for each address pair.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ for _, ep := range m.inner.mu.endpoints {
+ if !f(ep) {
+ return
+ }
+ }
+}
+
+// ForEachPrimaryEndpoint calls f for each primary address.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+ for _, ep := range m.inner.mu.primary {
+ f(ep)
+ }
+}
+
+// ReadOnly returns a readonly reference to a.
+func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
+ return ReadOnlyAddressableEndpointState{inner: a}
+}
+
+func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.releaseAddressStateLocked(addrState)
+}
+
+// releaseAddressState removes addrState from s's address state (primary and endpoints list).
+//
+// Preconditions: a.mu must be write locked.
+func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) {
+ oldPrimary := a.mu.primary
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ oldPrimary[len(oldPrimary)-1] = nil
+ break
+ }
+ }
+ delete(a.mu.endpoints, addrState.addr.Address)
+}
+
+// AddAndAcquirePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// AddAndAcquireTemporaryAddress adds a temporary address.
+//
+// Returns tcpip.ErrDuplicateAddress if the address exists.
+//
+// The temporary address's endpoint is acquired and returned.
+func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// addAndAcquireAddressLocked adds, acquires and returns a permanent or
+// temporary address.
+//
+// If the addressable endpoint already has the address in a non-permanent state,
+// and addAndAcquireAddressLocked is adding a permanent address, that address is
+// promoted in place and its properties set to the properties provided. If the
+// address already exists in any other state, then tcpip.ErrDuplicateAddress is
+// returned, regardless the kind of address that is being added.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) {
+ // attemptAddToPrimary is false when the address is already in the primary
+ // address list.
+ attemptAddToPrimary := true
+ addrState, ok := a.mu.endpoints[addr.Address]
+ if ok {
+ if !permanent {
+ // We are adding a non-permanent address but the address exists. No need
+ // to go any further since we can only promote existing temporary/expired
+ // addresses to permanent.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ addrState.mu.Lock()
+ if addrState.mu.kind.IsPermanent() {
+ addrState.mu.Unlock()
+ // We are adding a permanent address but a permanent address already
+ // exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr))
+ }
+
+ // We now promote the address.
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ // The address is already in the primary address list.
+ attemptAddToPrimary = false
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ // The address is already first in the primary address list.
+ attemptAddToPrimary = false
+ } else {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ }
+ case NeverPrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ break
+ }
+ }
+ }
+
+ if addrState == nil {
+ addrState = &addressState{
+ addressableEndpointState: a,
+ addr: addr,
+ }
+ a.mu.endpoints[addr.Address] = addrState
+ addrState.mu.Lock()
+ // We never promote an address to temporary - it can only be added as such.
+ // If we are actaully adding a permanent address, it is promoted below.
+ addrState.mu.kind = Temporary
+ }
+
+ // At this point we have an address we are either promoting from an expired or
+ // temporary address to permanent, promoting an expired address to temporary,
+ // or we are adding a new temporary or permanent address.
+ //
+ // The address MUST be write locked at this point.
+ defer addrState.mu.Unlock()
+
+ if permanent {
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr))
+ }
+
+ // Primary addresses are biased by 1.
+ addrState.mu.refs++
+ addrState.mu.kind = Permanent
+ }
+ // Acquire the address before returning it.
+ addrState.mu.refs++
+ addrState.mu.deprecated = deprecated
+ addrState.mu.configType = configType
+
+ if attemptAddToPrimary {
+ switch peb {
+ case NeverPrimaryEndpoint:
+ case CanBePrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary, addrState)
+ case FirstPrimaryEndpoint:
+ if cap(a.mu.primary) == len(a.mu.primary) {
+ a.mu.primary = append([]*addressState{addrState}, a.mu.primary...)
+ } else {
+ // Shift all the endpoints by 1 to make room for the new address at the
+ // front. We could have just created a new slice but this saves
+ // allocations when the slice has capacity for the new address.
+ primaryCount := len(a.mu.primary)
+ a.mu.primary = append(a.mu.primary, nil)
+ if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount {
+ panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount))
+ }
+ a.mu.primary[0] = addrState
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ }
+
+ return addrState, nil
+}
+
+// RemovePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if _, ok := a.mu.groups[addr]; ok {
+ panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
+ }
+
+ return a.removePermanentAddressLocked(addr)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ addrState, ok := a.mu.endpoints[addr]
+ if !ok {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// RemovePermanentEndpoint removes the passed endpoint if it is associated with
+// a and permanent.
+func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error {
+ addrState, ok := ep.(*addressState)
+ if !ok || addrState.addressableEndpointState != a {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error {
+ if !addrState.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ addrState.SetKind(PermanentExpired)
+ a.decAddressRefLocked(addrState)
+ return nil
+}
+
+// decAddressRef decrements the address's reference count and releases it once
+// the reference count hits 0.
+func (a *AddressableEndpointState) decAddressRef(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.decAddressRefLocked(addrState)
+}
+
+// decAddressRefLocked is like decAddressRef but with locking requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) {
+ addrState.mu.Lock()
+ defer addrState.mu.Unlock()
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr))
+ }
+
+ addrState.mu.refs--
+
+ if addrState.mu.refs != 0 {
+ return
+ }
+
+ // A non-expired permanent address must not have its reference count dropped
+ // to 0.
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind))
+ }
+
+ a.releaseAddressStateLocked(addrState)
+}
+
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: e.mu must be read locked
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if addrState, ok := a.mu.endpoints[localAddr]; ok {
+ if !addrState.IsAssigned(allowTemp) {
+ return nil
+ }
+
+ if !addrState.IncRef() {
+ panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ }
+
+ return addrState
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ if err != nil {
+ // addAndAcquireAddressLocked only returns an error if the address is
+ // already assigned but we just checked above if the address exists so we
+ // expect no error.
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ }
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+ return ep
+}
+
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
+
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+
+ return ep
+}
+
+// PrimaryAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.primary {
+ // Don't include tentative, expired or temporary endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
+ switch ep.GetKind() {
+ case PermanentTentative, PermanentExpired, Temporary:
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// PermanentAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.endpoints {
+ if !ep.GetKind().IsPermanent() {
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// JoinGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
+ if err != nil {
+ return false, err
+ }
+ // We have no need for the address endpoint.
+ a.decAddressRefLocked(ep)
+ }
+
+ a.mu.groups[group] = joins + 1
+ return !ok, nil
+}
+
+// LeaveGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ return false, tcpip.ErrBadLocalAddress
+ }
+
+ if joins == 1 {
+ a.removeGroupAddressLocked(group)
+ delete(a.mu.groups, group)
+ return true, nil
+ }
+
+ a.mu.groups[group] = joins - 1
+ return false, nil
+}
+
+// IsInGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ _, ok := a.mu.groups[group]
+ return ok
+}
+
+func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
+ if err := a.removePermanentAddressLocked(group); err != nil {
+ // removePermanentEndpointLocked would only return an error if group is
+ // not bound to the addressable endpoint, but we know it MUST be assigned
+ // since we have group in our map of groups.
+ panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
+ }
+}
+
+// Cleanup forcefully leaves all groups and removes all permanent addresses.
+func (a *AddressableEndpointState) Cleanup() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ for group := range a.mu.groups {
+ a.removeGroupAddressLocked(group)
+ }
+ a.mu.groups = make(map[tcpip.Address]uint32)
+
+ for _, ep := range a.mu.endpoints {
+ // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
+ // not a permanent address.
+ if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
+ }
+ }
+}
+
+var _ AddressEndpoint = (*addressState)(nil)
+
+// addressState holds state for an address.
+type addressState struct {
+ addressableEndpointState *AddressableEndpointState
+ addr tcpip.AddressWithPrefix
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ refs uint32
+ kind AddressKind
+ configType AddressConfigType
+ deprecated bool
+ }
+}
+
+// AddressWithPrefix implements AddressEndpoint.
+func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
+ return a.addr
+}
+
+// GetKind implements AddressEndpoint.
+func (a *addressState) GetKind() AddressKind {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.kind
+}
+
+// SetKind implements AddressEndpoint.
+func (a *addressState) SetKind(kind AddressKind) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.kind = kind
+}
+
+// IsAssigned implements AddressEndpoint.
+func (a *addressState) IsAssigned(allowExpired bool) bool {
+ if !a.addressableEndpointState.networkEndpoint.Enabled() {
+ return false
+ }
+
+ switch a.GetKind() {
+ case PermanentTentative:
+ return false
+ case PermanentExpired:
+ return allowExpired
+ default:
+ return true
+ }
+}
+
+// IncRef implements AddressEndpoint.
+func (a *addressState) IncRef() bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.mu.refs == 0 {
+ return false
+ }
+
+ a.mu.refs++
+ return true
+}
+
+// DecRef implements AddressEndpoint.
+func (a *addressState) DecRef() {
+ a.addressableEndpointState.decAddressRef(a)
+}
+
+// ConfigType implements AddressEndpoint.
+func (a *addressState) ConfigType() AddressConfigType {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.configType
+}
+
+// SetDeprecated implements AddressEndpoint.
+func (a *addressState) SetDeprecated(d bool) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.deprecated = d
+}
+
+// Deprecated implements AddressEndpoint.
+func (a *addressState) Deprecated() bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.deprecated
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
new file mode 100644
index 000000000..26787d0a3
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
+// endpoint state removes permanent addresses and leaves groups.
+func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ var s stack.AddressableEndpointState
+ s.Init(&ep)
+
+ addr := tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ }
+
+ {
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ }
+ // We don't need the address endpoint.
+ ep.DecRef()
+ }
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep == nil {
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
+ }
+ ep.DecRef()
+ }
+
+ group := tcpip.Address("\x02")
+ if added, err := s.JoinGroup(group); err != nil {
+ t.Fatalf("s.JoinGroup(%s): %s", group, err)
+ } else if !added {
+ t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
+ }
+ if !s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
+ }
+
+ s.Cleanup()
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ }
+ }
+ if s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ }
+}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 836682ea0..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go
index 8d18f3c8c..cf042309e 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,20 +46,27 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
- nicID tcpip.NICID
+ AddressableEndpointState
+
+ nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
- ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
-func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+ return nil
}
-func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (*fwdTestNetworkEndpoint) Enabled() bool {
+ return true
+}
+
+func (*fwdTestNetworkEndpoint) Disable() {}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -71,17 +79,13 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
-func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -94,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -106,7 +110,9 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu
return tcpip.ErrNotSupported
}
-func (*fwdTestNetworkEndpoint) Close() {}
+func (f *fwdTestNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
@@ -116,6 +122,11 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
@@ -145,13 +156,14 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
- return &fwdTestNetworkEndpoint{
- nicID: nicID,
+func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
+ e := &fwdTestNetworkEndpoint{
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: ep,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
@@ -186,6 +198,21 @@ func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
// fwdTestPacketInfo holds all the information about an outbound packet.
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 4a521eca9..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -60,11 +60,11 @@ func DefaultTables() *IPTables {
v4Tables: [numTables]Table{
natID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -83,9 +83,9 @@ func DefaultTables() *IPTables {
},
mangleID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -101,10 +101,10 @@ func DefaultTables() *IPTables {
},
filterID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -125,11 +125,11 @@ func DefaultTables() *IPTables {
v6Tables: [numTables]Table{
natID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -148,9 +148,9 @@ func DefaultTables() *IPTables {
},
mangleID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -166,10 +166,10 @@ func DefaultTables() *IPTables {
},
filterID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, tcpip.ErrNotConnected
}
- return it.connections.originalDst(epID)
+ return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 5f1b2af64..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -21,78 +21,139 @@ import (
)
// AcceptTarget accepts packets.
-type AcceptTarget struct{}
+type AcceptTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (at *AcceptTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: at.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
// DropTarget drops packets.
-type DropTarget struct{}
+type DropTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (dt *DropTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: dt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
-type ErrorTarget struct{}
+type ErrorTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (et *ErrorTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: et.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
type UserChainTarget struct {
+ // Name is the chain name.
Name string
+
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (uc *UserChainTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: uc.NetworkProtocol,
+ }
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
// ReturnTarget returns from the current chain. If the chain is a built-in, the
// hook's underflow should be called.
-type ReturnTarget struct{}
+type ReturnTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (rt *ReturnTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: rt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const RedirectTargetName = "REDIRECT"
+
// RedirectTarget redirects the packet by modifying the destination port/IP.
-// Min and Max values for IP and Ports in the struct indicate the range of
-// values which can be used to redirect.
+// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
+// them.
type RedirectTarget struct {
- // TODO(gvisor.dev/issue/170): Other flags need to be added after
- // we support them.
- // RangeProtoSpecified flag indicates single port is specified to
- // redirect.
- RangeProtoSpecified bool
+ // Addr indicates address used to redirect.
+ Addr tcpip.Address
- // MinIP indicates address used to redirect.
- MinIP tcpip.Address
+ // Port indicates port used to redirect.
+ Port uint16
- // MaxIP indicates address used to redirect.
- MaxIP tcpip.Address
-
- // MinPort indicates port used to redirect.
- MinPort uint16
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
- // MaxPort indicates port used to redirect.
- MaxPort uint16
+// ID implements Target.ID.
+func (rt *RedirectTarget) ID() TargetID {
+ return TargetID{
+ Name: RedirectTargetName,
+ NetworkProtocol: rt.NetworkProtocol,
+ }
}
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -103,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
- rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
- rt.MinIP = address
- rt.MaxIP = address
+ rt.Addr = address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.MinPort)
+ udpHeader.SetDestinationPort(rt.Port)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -139,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.MinIP)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 093ee6881..7b3f3e88b 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -104,8 +104,20 @@ type IPTables struct {
reaperDone chan struct{}
}
-// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules.
+// A Table defines a set of chains and hooks into the network stack.
+//
+// It is a list of Rules, entry points (BuiltinChains), and error handlers
+// (Underflows). As packets traverse netstack, they hit hooks. When a packet
+// hits a hook, iptables compares it to Rules starting from that hook's entry
+// point. So if a packet hits the Input hook, we look up the corresponding
+// entry point in BuiltinChains and jump to that point.
+//
+// If the Rule doesn't match the packet, iptables continues to the next Rule.
+// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept
+// or RuleDrop) that causes the packet to stop traversing iptables. It can also
+// jump to other rules or perform custom actions based on Rule.Target.
+//
+// Underflow Rules are invoked when a chain returns without reaching a verdict.
//
// +stateify savable
type Table struct {
@@ -260,6 +272,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return true
}
+// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
+// applies.
+func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber {
+ switch len(fl.Src) {
+ case header.IPv4AddressSize:
+ return header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src))
+}
+
// filterAddress returns whether addr matches the filter.
func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
matches := true
@@ -285,8 +309,23 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
+// A TargetID uniquely identifies a target.
+type TargetID struct {
+ // Name is the target name as stored in the xt_entry_target struct.
+ Name string
+
+ // NetworkProtocol is the protocol to which the target applies.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+
+ // Revision is the version of the target.
+ Revision uint8
+}
+
// A Target is the interface for taking an action for a packet.
type Target interface {
+ // ID uniquely identifies the Target.
+ ID() TargetID
+
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 8416dbcdb..73a01c2dd 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -150,10 +150,10 @@ type ndpDNSSLEvent struct {
type ndpDHCPv6Event struct {
nicID tcpip.NICID
- configuration stack.DHCPv6ConfigurationFromNDPRA
+ configuration ipv6.DHCPv6ConfigurationFromNDPRA
}
-var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
@@ -170,7 +170,7 @@ type ndpDispatcher struct {
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
-// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
@@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add
}
}
-// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.
return n.rememberRouter
}
-// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip
}
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
return n.rememberPrefix
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi
}
}
-// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption.
func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
if c := n.rdnssC; c != nil {
c <- ndpRDNSSEvent{
@@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
-// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+// Implements ipv6.NDPDispatcher.OnDNSSearchListOption.
func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
if n.dnsslC != nil {
n.dnsslC <- ndpDNSSLEvent{
@@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s
}
}
-// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
-func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
if c := n.dhcpv6ConfigurationC; c != nil {
c <- ndpDHCPv6Event{
nicID,
@@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPDisp: &ndpDisp,
- }
-
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = test.retransTimer
- opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
e := channelLinkWithHeaderLength{
Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ipv6.NDPConfigurations{
+ RetransmitTimer: test.retransTimer,
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ },
+ })},
+ })
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -558,6 +559,26 @@ func TestDADResolve(t *testing.T) {
}
}
+func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
+}
+
// TestDADFail tests to make sure that the DAD process fails if another node is
// detected to be performing DAD on the same address (receive an NS message from
// a node doing DAD for the same address), or if another node is detected to own
@@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) {
tests := []struct {
name string
- makeBuf func(tgt tcpip.Address) buffer.Prependable
+ rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "RxSolicit",
- func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(tgt)
- snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
- })
-
- return hdr
-
- },
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RxSolicit",
+ rxPkt: rxNDPSolicit,
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborSolicit
},
},
{
- "RxAdvert",
- func(tgt tcpip.Address) buffer.Prependable {
+ name: "RxAdvert",
+ rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
@@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) {
SrcAddr: tgt,
DstAddr: header.IPv6AllNodesMulticastAddress,
})
-
- return hdr
-
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborAdvert
},
},
@@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = time.Second * 2
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
+ ndpConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -664,13 +663,8 @@ func TestDADFail(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Receive a packet to simulate multiple nodes owning or
- // attempting to own the same address.
- hdr := test.makeBuf(addr1)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(header.IPv6ProtocolNumber, pkt)
+ // Receive a packet to simulate an address conflict.
+ test.rxPkt(e, addr1)
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
@@ -754,18 +748,19 @@ func TestDADStop(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.NDPConfigurations{
+
+ ndpConfigs := ipv6.NDPConfigurations{
RetransmitTimer: time.Second,
DupAddrDetectTransmits: 2,
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -815,19 +810,6 @@ func TestDADStop(t *testing.T) {
}
}
-// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
-// we attempt to update NDP configurations using an invalid NICID.
-func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- })
-
- // No NIC with ID 1 yet.
- if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
- }
-}
-
// TestSetNDPConfigurations tests that we can update and use per-interface NDP
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
@@ -863,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
})
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
@@ -892,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Update the NDP configurations on NIC(1) to use DAD.
- configs := stack.NDPConfigurations{
+ configs := ipv6.NDPConfigurations{
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
RetransmitTimer: test.retransmitTimer,
}
- if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(configs)
}
// Created after updating NIC(1)'s NDP configurations
@@ -1113,12 +1099,13 @@ func TestNoRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
s.SetForwarding(ipv6.ProtocolNumber, forwarding)
@@ -1151,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1192,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func(addr tcpip.Address, discovered bool) {
@@ -1285,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) {
}
// TestRouterDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
@@ -1293,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1306,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
- if i <= stack.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredDefaultRouters {
select {
case e := <-ndpDisp.routerC:
if diff := checkRouterEvent(e, llAddr, true); diff != "" {
@@ -1358,12 +1348,13 @@ func TestNoPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
s.SetForwarding(ipv6.ProtocolNumber, forwarding)
@@ -1399,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1445,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1545,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1621,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// TestPrefixDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
+// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
}
- optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
- prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+ optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
// Receive an RA with 2 more than the max number of discovered on-link
// prefixes.
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
prefixAddr[7] = byte(i)
prefix := tcpip.AddressWithPrefix{
@@ -1665,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
- if i < stack.MaxDiscoveredOnLinkPrefixes {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < ipv6.MaxDiscoveredOnLinkPrefixes {
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
@@ -1716,12 +1711,13 @@ func TestNoAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
s.SetForwarding(ipv6.ProtocolNumber, forwarding)
@@ -1749,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr(t *testing.T) {
+func TestAutoGenAddr2(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1766,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1876,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MaxDesyncFactor = time.Nanosecond
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1931,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2119,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
tests := []struct {
name string
@@ -2160,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2211,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = 0
+ ipv6.MaxDesyncFactor = 0
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2228,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2294,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2317,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2382,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Wait for all the temporary addresses to get invalidated.
@@ -2439,17 +2443,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2462,16 +2466,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2545,9 +2550,12 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// as paased.
ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2565,9 +2573,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
- }
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
@@ -2655,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: test.nicNameFromID,
+ },
+ })},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: test.nicNameFromID,
- },
})
s.SetRouteTable([]tcpip.Route{{
@@ -2739,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
ndpDisp.dadC = make(chan ndpDADEvent, 2)
ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
ndpConfigs.RetransmitTimer = retransmitTimer
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Do SLAAC for prefix.
@@ -2754,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// DAD failure to restart the local generation process.
addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
expectAutoGenAddrAsyncEvent(addr, newAddr)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
@@ -2794,14 +2802,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: ndpDisp,
- UseNeighborCache: useNeighborCache,
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -3036,11 +3045,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -3258,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
const minVLSeconds = 1
savedIL := header.NDPInfiniteLifetime
- savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
header.NDPInfiniteLifetime = savedIL
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3307,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3357,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
const newMinVL = 4
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3449,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
}
e := channel.New(10, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3515,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3700,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3781,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -3856,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const lifetimeSeconds = 10
// Needed for the temporary address sub test.
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
@@ -3938,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -3963,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
},
@@ -3977,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "Temporary address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -4029,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
},
- SecretKey: secretKey,
- },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4059,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
expectDADEvent(t, &ndpDisp, addr.Address, false)
@@ -4119,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
subnet tcpip.Subnet
triggerSLAACFn func(e *channel.Endpoint)
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -4142,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
AutoGenAddressConflictRetries: maxRetries,
@@ -4165,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4198,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4250,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4296,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
// Simulate a DAD conflict after some time has passed.
time.Sleep(failureTimer)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4459,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -4509,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4694,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func() (bool, ndpRouterEvent) {
@@ -4967,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
t.Helper()
select {
case e := <-ndpDisp.dhcpv6ConfigurationC:
@@ -5002,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Even if the first RA reports no DHCPv6 configurations are available, the
// dispatcher should get an event.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
// Receiving the same update again should not result in an event to the
// dispatcher.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
@@ -5011,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to Managed Address.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to none.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
expectNoDHCPv6Event()
@@ -5031,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
//
// Note, when the M flag is set, the O flag is redundant.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
expectNoDHCPv6Event()
// Even though the DHCPv6 flags are different, the effective configuration is
@@ -5044,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
@@ -5059,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
}
@@ -5217,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) {
}
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -5357,12 +5373,13 @@ func TestStopStartSolicitingRouters(t *testing.T) {
checker.NDPRS())
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 27e1feec0..4df288798 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
- case Reachable, Static:
+ case Stale:
+ entry.handlePacketQueuedLocked()
+ fallthrough
+ case Reachable, Static, Delay, Probe:
+ // As per RFC 4861 section 7.3.3:
+ // "Neighbor Unreachability Detection operates in parallel with the sending
+ // of packets to a neighbor. While reasserting a neighbor's reachability,
+ // a node continues sending packets to that neighbor using the cached
+ // link-layer address."
return entry.neigh, nil, nil
-
- case Unknown, Incomplete, Stale, Delay, Probe:
+ case Unknown, Incomplete:
entry.addWakerLocked(w)
if entry.done == nil {
@@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.handlePacketQueuedLocked()
return entry.neigh, entry.done, tcpip.ErrWouldBlock
-
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index a0b7da5cd..fcd54ed83 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) {
}
// Verify the entry exists
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
- }
- if t.Failed() {
- t.FailNow()
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ {
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
}
// Notify of a link address change
@@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) {
IsRouter: false,
})
- // Requesting the entry again should start address resolution
+ // Requesting the entry again should start neighbor reachability confirmation.
+ //
+ // Verify the entry's new link address and the new state.
{
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
- select {
- case <-doneCh:
- default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
- // Verify the entry's new link address
+ // Verify that the neighbor is now reachable.
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- want = NeighborEntry{
+ want := NeighborEntry{
Addr: entry.Addr,
LocalAddr: entry.LocalAddr,
LinkAddr: updatedLinkAddr,
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 213646160..4d69a4de1 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// NeighborEntry describes a neighboring device in the local network.
@@ -235,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -276,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -439,7 +440,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.notifyWakersLocked()
}
- if e.isRouter && !flags.IsRouter {
+ if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
// "In those cases where the IsRouter flag changes from TRUE to FALSE as
// a result of this update, the node MUST remove that router from the
// Default Router List and update the Destination Cache entries for all
@@ -447,9 +448,17 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// 7.3.3. This is needed to detect when a node that is used as a router
// stops forwarding packets due to being configured as a host."
// - RFC 4861 section 7.2.5
- e.nic.mu.Lock()
- e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr)
- e.nic.mu.Unlock()
+ //
+ // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
+ // here.
+ ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint"))
+ }
+
+ if ndpEP, ok := ep.(NDPEndpoint); ok {
+ ndpEP.InvalidateDefaultRouter(e.neigh.Addr)
+ }
}
e.isRouter = flags.IsRouter
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index e530ec7ea..e79abebca 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -226,25 +227,23 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- id: entryTestNICID,
- linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+
+ id: entryTestNICID,
stack: &Stack{
clock: clock,
nudDisp: &disp,
},
}
+ nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
+ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
+ }
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
- // Stub out ndpState to verify modification of default routers.
- nic.mu.ndp = ndpState{
- nic: &nic,
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- }
-
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
nic: &nic,
@@ -817,6 +816,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
+ ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
+
e.mu.Lock()
e.handlePacketQueuedLocked()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
@@ -830,9 +831,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
}
- e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{
- invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}),
- }
+
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -841,8 +840,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
if got, want := e.isRouter, false; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
}
- if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok {
- t.Errorf("unexpected defaultRouter for %s", entryTestAddr1)
+ if ipv6EP.invalidatedRtr != e.neigh.Addr {
+ t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr)
}
e.mu.Unlock()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 2875a5b60..8828cc5fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math/rand"
"reflect"
- "sort"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sleep"
@@ -28,39 +27,37 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var ipv4BroadcastAddr = tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 8 * header.IPv4AddressSize,
- },
-}
+var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
+ LinkEndpoint
+
stack *Stack
id tcpip.NICID
name string
- linkEP LinkEndpoint
context NICContext
- stats NICStats
- neigh *neighborCache
+ stats NICStats
+ neigh *neighborCache
+
+ // The network endpoints themselves may be modified by calling the interface's
+ // methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
mu struct {
sync.RWMutex
- enabled bool
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
- ndp ndpState
}
}
@@ -84,25 +81,6 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
-// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
-type PrimaryEndpointBehavior int
-
-const (
- // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
- CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
-
- // FirstPrimaryEndpoint indicates the endpoint should be the first
- // primary endpoint considered. If there are multiple endpoints with
- // this behavior, the most recently-added one will be first.
- FirstPrimaryEndpoint
-
- // NeverPrimaryEndpoint indicates the endpoint should never be a
- // primary endpoint.
- NeverPrimaryEndpoint
-)
-
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -114,27 +92,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
+ LinkEndpoint: ep,
+
stack: stack,
id: id,
name: name,
- linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
- nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
- nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- nic.mu.ndp.initializeTempAddrState()
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -162,37 +129,40 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.linkEP.Attach(nic)
+ nic.LinkEndpoint.Attach(nic)
return nic
}
-// enabled returns true if n is enabled.
-func (n *NIC) enabled() bool {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- return enabled
+func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+ return n.networkEndpoints[proto]
}
-// disable disables n.
+// Enabled implements NetworkInterface.
+func (n *NIC) Enabled() bool {
+ return atomic.LoadUint32(&n.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the NIC.
//
-// It undoes the work done by enable.
-func (n *NIC) disable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if !enabled {
- return nil
+// Returns true if the enabled status was updated.
+func (n *NIC) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&n.enabled, 1) == 0
}
+ return atomic.SwapUint32(&n.enabled, 0) == 1
+}
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() {
n.mu.Lock()
- err := n.disableLocked()
+ n.disableLocked()
n.mu.Unlock()
- return err
}
// disableLocked disables n.
@@ -200,9 +170,9 @@ func (n *NIC) disable() *tcpip.Error {
// It undoes the work done by enable.
//
// n MUST be locked.
-func (n *NIC) disableLocked() *tcpip.Error {
- if !n.mu.enabled {
- return nil
+func (n *NIC) disableLocked() {
+ if !n.setEnabled(false) {
+ return
}
// TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
@@ -210,38 +180,9 @@ func (n *NIC) disableLocked() *tcpip.Error {
// again, and applications may not know that the underlying NIC was ever
// disabled.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- n.mu.ndp.stopSolicitingRouters()
- n.mu.ndp.cleanupState(false /* hostOnly */)
-
- // Stop DAD for all the unicast IPv6 endpoints that are in the
- // permanentTentative state.
- for _, r := range n.mu.endpoints {
- if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
- n.mu.ndp.stopDuplicateAddressDetection(addr)
- }
- }
-
- // The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
- }
-
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- // The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
-
- // The address may have already been removed.
- if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
+ for _, ep := range n.networkEndpoints {
+ ep.Disable()
}
-
- n.mu.enabled = false
- return nil
}
// enable enables n.
@@ -251,162 +192,38 @@ func (n *NIC) disableLocked() *tcpip.Error {
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if enabled {
- return nil
- }
-
n.mu.Lock()
defer n.mu.Unlock()
- if n.mu.enabled {
+ if !n.setEnabled(true) {
return nil
}
- n.mu.enabled = true
-
- // Create an endpoint to receive broadcast packets on this interface.
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
-
- // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
- // multicast group. Note, the IANA calls the all-hosts multicast group the
- // all-systems multicast group.
- if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil {
- return err
- }
- }
-
- // Join the IPv6 All-Nodes Multicast group if the stack is configured to
- // use IPv6. This is required to ensure that this node properly receives
- // and responds to the various NDP messages that are destined to the
- // all-nodes multicast address. An example is the Neighbor Advertisement
- // when we perform Duplicate Address Detection, or Router Advertisement
- // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
- // section 4.2 for more information.
- //
- // Also auto-generate an IPv6 link-local address based on the NIC's
- // link address if it is configured to do so. Note, each interface is
- // required to have IPv6 link-local unicast address, as per RFC 4291
- // section 2.1.
- _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
- if !ok {
- return nil
- }
-
- // Join the All-Nodes multicast group before starting DAD as responses to DAD
- // (NDP NS) messages may be sent to the All-Nodes multicast group if the
- // source address of the NDP NS is the unspecified address, as per RFC 4861
- // section 7.2.4.
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
- // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
- // state.
- //
- // Addresses may have aleady completed DAD but in the time since the NIC was
- // last enabled, other devices may have acquired the same addresses.
- for _, r := range n.mu.endpoints {
- addr := r.address()
- if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
- continue
- }
-
- r.setKind(permanentTentative)
- if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.Enable(); err != nil {
return err
}
}
- // Do not auto-generate an IPv6 link-local address for loopback devices.
- if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
- // The valid and preferred lifetime is infinite for the auto-generated
- // link-local address.
- n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
- }
-
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyways.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !n.stack.Forwarding(header.IPv6ProtocolNumber) {
- n.mu.ndp.startSolicitingRouters()
- }
-
return nil
}
-// remove detaches NIC from the link endpoint, and marks existing referenced
-// network endpoints expired. This guarantees no packets between this NIC and
-// the network stack.
+// remove detaches NIC from the link endpoint and releases network endpoint
+// resources. This guarantees no packets between this NIC and the network
+// stack.
func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
n.disableLocked()
- // TODO(b/151378115): come up with a better way to pick an error than the
- // first one.
- var err *tcpip.Error
-
- // Forcefully leave multicast groups.
- for nid := range n.mu.mcastJoins {
- if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
- err = tempErr
- }
- }
-
- // Remove permanent and permanentTentative addresses, so no packet goes out.
- for nid, ref := range n.mu.endpoints {
- switch ref.getKind() {
- case permanentTentative, permanent:
- if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
- err = tempErr
- }
- }
- }
-
- // Release any resources the network endpoint may hold.
for _, ep := range n.networkEndpoints {
ep.Close()
}
// Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
-
- return err
-}
-
-// becomeIPv6Router transitions n into an IPv6 router.
-//
-// When transitioning into an IPv6 router, host-only state (NDP discovered
-// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated and NDP router solicitations will be stopped.
-func (n *NIC) becomeIPv6Router() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.cleanupState(true /* hostOnly */)
- n.mu.ndp.stopSolicitingRouters()
-}
-
-// becomeIPv6Host transitions n into an IPv6 host.
-//
-// When transitioning into an IPv6 host, NDP router solicitations will be
-// started.
-func (n *NIC) becomeIPv6Host() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.startSolicitingRouters()
+ n.LinkEndpoint.Attach(nil)
+ return nil
}
// setPromiscuousMode enables or disables promiscuous mode.
@@ -423,217 +240,113 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
-func (n *NIC) isLoopback() bool {
- return n.linkEP.Capabilities()&CapabilityLoopback != 0
+// IsLoopback implements NetworkInterface.
+func (n *NIC) IsLoopback() bool {
+ return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
}
-// setSpoofing enables or disables address spoofing.
-func (n *NIC) setSpoofing(enable bool) {
- n.mu.Lock()
- n.mu.spoofing = enable
- n.mu.Unlock()
-}
-
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 {
- return n.primaryIPv6Endpoint(remoteAddr)
- }
-
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, r := range n.mu.primary[protocol] {
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- if !r.deprecated {
- if r.tryIncRef() {
- // r is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- deprecatedEndpoint.decRefLocked()
- deprecatedEndpoint = nil
- }
-
- return r
- }
- } else if deprecatedEndpoint == nil && r.tryIncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of r in
- // case n doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, r's reference count
- // will be decremented when such an endpoint is found.
- deprecatedEndpoint = r
+// WritePacket implements NetworkLinkEndpoint.
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 5.2 (for IPv6):
+ // Once the IP address of the next-hop node is known, the sender
+ // examines the Neighbor Cache for link-layer information about that
+ // neighbor. If no entry exists, the sender creates one, sets its state
+ // to INCOMPLETE, initiates Address Resolution, and then queues the data
+ // packet pending completion of address resolution.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ r := r.Clone()
+ n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ return nil
}
+ return err
}
- // n doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
- // endpoints either).
- return deprecatedEndpoint
+ return n.writePacket(r, gso, protocol, pkt)
}
-// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
-// 6724 section 5).
-type ipv6AddrCandidate struct {
- ref *referencedNetworkEndpoint
- scope header.IPv6AddressScope
-}
+func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
-// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- n.mu.RLock()
- ref := n.primaryIPv6EndpointRLocked(remoteAddr)
- n.mu.RUnlock()
- return ref
-}
-
-// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-//
-// n.mu MUST be read locked.
-func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
-
- if len(primaryAddrs) == 0 {
- return nil
+ if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ return err
}
- // Create a candidate set of available addresses we can potentially use as a
- // source address.
- cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
- for _, r := range primaryAddrs {
- // If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- addr := r.address()
- scope, err := header.ScopeForIPv6Address(addr)
- if err != nil {
- // Should never happen as we got r from the primary IPv6 endpoint list and
- // ScopeForIPv6Address only returns an error if addr is not an IPv6
- // address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
- }
-
- cs = append(cs, ipv6AddrCandidate{
- ref: r,
- scope: scope,
- })
- }
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
- remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
- if err != nil {
- // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+// WritePackets implements NetworkLinkEndpoint.
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
+ // is being peformed like WritePacket.
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
}
- // Sort the addresses as per RFC 6724 section 5 rules 1-3.
- //
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
- sort.Slice(cs, func(i, j int) bool {
- sa := cs[i]
- sb := cs[j]
-
- // Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.ref.address() == remoteAddr {
- return true
- }
- if sb.ref.address() == remoteAddr {
- return false
- }
-
- // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
- if sa.scope < sb.scope {
- return sa.scope >= remoteScope
- } else if sb.scope < sa.scope {
- return sb.scope < remoteScope
- }
-
- // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
- if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
- // If sa is not deprecated, it is preferred over sb.
- return sbDep
- }
-
- // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
- if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp {
- return saTemp
- }
-
- // sa and sb are equal, return the endpoint that is closest to the front of
- // the primary endpoint list.
- return i < j
- })
-
- // Return the most preferred address that can have its reference count
- // incremented.
- for _, c := range cs {
- if r := c.ref; r.tryIncRef() {
- return r
- }
- }
+ n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return writtenPackets, err
+}
- return nil
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.mu.spoofing = enable
+ n.mu.Unlock()
}
-// hasPermanentAddrLocked returns true if n has a permanent (including currently
-// tentative) address, addr.
-func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+// primaryAddress returns an address that can be used to communicate with
+// remoteAddr.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+ ep, ok := n.networkEndpoints[protocol]
if !ok {
- return false
+ return nil
}
- kind := ref.getKind()
-
- return kind == permanent || kind == permanentTentative
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
-type getRefBehaviour int
+type getAddressBehaviour int
const (
// spoofing indicates that the NIC's spoofing flag should be observed when
- // getting a NIC's referenced network endpoint.
- spoofing getRefBehaviour = iota
+ // getting a NIC's address endpoint.
+ spoofing getAddressBehaviour = iota
// promiscuous indicates that the NIC's promiscuous flag should be observed
- // when getting a NIC's referenced network endpoint.
+ // when getting a NIC's address endpoint.
promiscuous
)
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
+func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
}
-// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address.
+// getAddressEpOrCreateTemp returns the address endpoint for the given protocol
+// and address.
//
// If none exists a temporary one may be created if we are in promiscuous mode
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
@@ -641,9 +354,8 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
//
// If the address is the IPv4 broadcast address for an endpoint's network, that
// endpoint will be returned.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
+func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint {
n.mu.RLock()
-
var spoofingOrPromiscuous bool
switch tempRef {
case spoofing:
@@ -651,274 +363,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
}
-
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // An endpoint with this id exists, check if it can be used and return it.
- if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
- n.mu.RUnlock()
- return nil
- }
-
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
- }
-
- if protocol == header.IPv4ProtocolNumber {
- if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
- n.mu.RUnlock()
- return ref
- }
- }
n.mu.RUnlock()
-
- if !spoofingOrPromiscuous {
- return nil
- }
-
- // Try again with the lock in exclusive mode. If we still can't get the
- // endpoint, create a new "temporary" endpoint. It will only exist while
- // there's a route through it.
- n.mu.Lock()
- ref := n.getRefOrCreateTempLocked(protocol, address, peb)
- n.mu.Unlock()
- return ref
+ return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb)
}
-// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the
-// broadcast address for the endpoint's network or an address in the endpoint's
-// subnet if the NIC is a loopback interface. This matches linux behaviour.
-//
-// n.mu MUST be read or write locked.
-func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint {
- for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses or considers the loopback
- // interface bound to an address's whole subnet (on linux).
- if ref.protocol != header.IPv4ProtocolNumber {
- continue
- }
-
- subnet := ref.addrWithPrefix().Subnet()
- if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() {
- return ref
- }
+// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
+// is passed to indicate whether or not we should generate temporary endpoints.
+func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ if ep, ok := n.networkEndpoints[protocol]; ok {
+ return ep.AcquireAssignedAddress(address, createTemp, peb)
}
return nil
}
-/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
-/// and returns a temporary endpoint.
-//
-// If the address is the IPv4 broadcast address for an endpoint's network, that
-// endpoint will be returned.
-//
-// n.mu must be write locked.
-func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // No need to check the type as we are ok with expired endpoints at this
- // point.
- if ref.tryIncRef() {
- return ref
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once the
- // lock is released. Remove it here so we can create a new (temporary) one.
- // The removal logic waiting for the lock handles this case.
- n.removeEndpointLocked(ref)
- }
-
- if protocol == header.IPv4ProtocolNumber {
- if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
- return ref
- }
- }
-
- // Add a new temporary endpoint.
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return nil
- }
- ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb, temporary, static, false)
- return ref
-}
-
-// addAddressLocked adds a new protocolAddress to n.
-//
-// If n already has the address in a non-permanent state, and the kind given is
-// permanent, that address will be promoted in place and its properties set to
-// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
- // TODO(b/141022673): Validate IP addresses before adding them.
-
- // Sanity check.
- id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.mu.endpoints[id]; ok {
- // Endpoint already exists.
- if kind != permanent {
- return nil, tcpip.ErrDuplicateAddress
- }
- switch ref.getKind() {
- case permanentTentative, permanent:
- // The NIC already have a permanent endpoint with that address.
- return nil, tcpip.ErrDuplicateAddress
- case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect the new peb,
- // configType and deprecated status.
- if ref.tryIncRef() {
- // TODO(b/147748385): Perform Duplicate Address Detection when promoting
- // an IPv6 endpoint to permanent.
- ref.setKind(permanent)
- ref.deprecated = deprecated
- ref.configType = configType
-
- refs := n.mu.primary[ref.protocol]
- for i, r := range refs {
- if r == ref {
- switch peb {
- case CanBePrimaryEndpoint:
- return ref, nil
- case FirstPrimaryEndpoint:
- if i == 0 {
- return ref, nil
- }
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- case NeverPrimaryEndpoint:
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- return ref, nil
- }
- }
- }
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- return ref, nil
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once
- // the lock is released. Remove it here so we can create a new
- // (permanent) one. The removal logic waiting for the lock handles this
- // case.
- n.removeEndpointLocked(ref)
- }
- }
-
+// addAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ return tcpip.ErrUnknownProtocol
}
- isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
-
- // If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process if the NIC is
- // enabled. If the NIC is not enabled, DAD will be started when the NIC is
- // enabled.
- if isIPv6Unicast && kind == permanent {
- kind = permanentTentative
+ addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ if err == nil {
+ // We have no need for the address endpoint.
+ addressEndpoint.DecRef()
}
-
- ref := &referencedNetworkEndpoint{
- refs: 1,
- addr: protocolAddress.AddressWithPrefix,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
- configType: configType,
- deprecated: deprecated,
- }
-
- // Set up resolver if link address resolution exists for this protocol.
- if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
- ref.linkCache = n.stack
- ref.linkRes = linkRes
- }
- }
-
- // If we are adding an IPv6 unicast address, join the solicited-node
- // multicast address.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
- if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
- return nil, err
- }
- }
-
- n.mu.endpoints[id] = ref
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
- if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
- if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
- return nil, err
- }
- }
-
- return ref, nil
-}
-
-// AddAddress adds a new address to n, so that it starts accepting packets
-// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
- // Add the endpoint.
- n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
- n.mu.Unlock()
-
return err
}
-// AllAddresses returns all addresses (primary and non-primary) associated with
+// allPermanentAddresses returns all permanent addresses associated with
// this NIC.
-func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
- for _, ref := range n.mu.endpoints {
- // Don't include tentative, expired or temporary endpoints to
- // avoid confusion and prevent the caller from using those.
- switch ref.getKind() {
- case permanentExpired, temporary:
- continue
+func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
+ var addrs []tcpip.ProtocolAddress
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PermanentAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ref.protocol,
- AddressWithPrefix: ref.addrWithPrefix(),
- })
}
return addrs
}
-// PrimaryAddresses returns the primary addresses associated with this NIC.
-func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
+// primaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.mu.primary {
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints
- // to avoid confusion and prevent the caller from using
- // those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: ref.addrWithPrefix(),
- })
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PrimaryAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
return addrs
@@ -930,147 +422,25 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- list, ok := n.mu.primary[proto]
+ ep, ok := n.networkEndpoints[proto]
if !ok {
return tcpip.AddressWithPrefix{}
}
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints to avoid confusion
- // and prevent the caller from using those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- if !ref.deprecated {
- return ref.addrWithPrefix()
- }
-
- if deprecatedEndpoint == nil {
- deprecatedEndpoint = ref
- }
- }
-
- if deprecatedEndpoint != nil {
- return deprecatedEndpoint.addrWithPrefix()
- }
-
- return tcpip.AddressWithPrefix{}
-}
-
-// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
-// by peb.
-//
-// n MUST be locked.
-func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
- switch peb {
- case CanBePrimaryEndpoint:
- n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
- case FirstPrimaryEndpoint:
- n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
- }
-}
-
-func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
- id := NetworkEndpointID{LocalAddress: r.address()}
-
- // Nothing to do if the reference has already been replaced with a different
- // one. This happens in the case where 1) this endpoint's ref count hit zero
- // and was waiting (on the lock) to be removed and 2) the same address was
- // re-added in the meantime by removing this endpoint from the list and
- // adding a new one.
- if n.mu.endpoints[id] != r {
- return
- }
-
- if r.getKind() == permanent {
- panic("Reference count dropped to zero before being removed")
- }
-
- delete(n.mu.endpoints, id)
- refs := n.mu.primary[r.protocol]
- for i, ref := range refs {
- if ref == r {
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- refs[len(refs)-1] = nil
- break
- }
- }
+ return ep.MainAddress()
}
-func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
- n.mu.Lock()
- n.removeEndpointLocked(r)
- n.mu.Unlock()
-}
-
-func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadLocalAddress
- }
-
- kind := r.getKind()
- if kind != permanent && kind != permanentTentative {
- return tcpip.ErrBadLocalAddress
- }
-
- switch r.protocol {
- case header.IPv6ProtocolNumber:
- return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
- default:
- r.expireLocked()
- return nil
- }
-}
-
-func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
- addr := r.addrWithPrefix()
-
- isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
-
- if isIPv6Unicast {
- n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
-
- // If we are removing an address generated via SLAAC, cleanup
- // its SLAAC resources and notify the integrator.
- switch r.configType {
- case slaac:
- n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- case slaacTemp:
- n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- }
- }
-
- r.expireLocked()
-
- // At this point the endpoint is deleted.
-
- // If we are removing an IPv6 unicast address, leave the solicited-node
- // multicast address.
- //
- // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
- // multicast group may be left by user action.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr.Address)
- if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+// removeAddress removes an address from n.
+func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ continue
+ } else {
return err
}
}
- return nil
-}
-
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
- return n.removePermanentAddressLocked(addr)
+ return tcpip.ErrBadLocalAddress
}
func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
@@ -1121,91 +491,66 @@ func (n *NIC) clearNeighbors() *tcpip.Error {
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.joinGroupLocked(protocol, addr)
-}
-
-// joinGroupLocked adds a new endpoint for the given multicast address, if none
-// exists yet. Otherwise it just increments its count. n MUST be locked before
-// joinGroupLocked is called.
-func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
// outlined in RFC 3810 section 5.
- id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- if joins == 0 {
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return tcpip.ErrUnknownProtocol
- }
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
}
- n.mu.mcastJoins[id] = joins + 1
- return nil
+
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ _, err := gep.JoinGroup(addr)
+ return err
}
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.leaveGroupLocked(addr, false /* force */)
-}
+func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
-// leaveGroupLocked decrements the count for the given multicast address, and
-// when it reaches zero removes the endpoint for this address. n MUST be locked
-// before leaveGroupLocked is called.
-//
-// If force is true, then the count for the multicast addres is ignored and the
-// endpoint will be removed immediately.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
- id := NetworkEndpointID{addr}
- joins, ok := n.mu.mcastJoins[id]
+ gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- // There are no joins with this address on this NIC.
- return tcpip.ErrBadLocalAddress
+ return tcpip.ErrNotSupported
}
- joins--
- if force || joins == 0 {
- // There are no outstanding joins or we are forced to leave, clean up.
- delete(n.mu.mcastJoins, id)
- return n.removePermanentAddressLocked(addr)
+ if _, err := gep.LeaveGroup(addr); err != nil {
+ return err
}
- n.mu.mcastJoins[id] = joins
return nil
}
// isInGroup returns true if n has joined the multicast group addr.
func (n *NIC) isInGroup(addr tcpip.Address) bool {
- n.mu.RLock()
- joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
- n.mu.RUnlock()
+ for _, ep := range n.networkEndpoints {
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ continue
+ }
- return joins != 0
+ if gep.IsInGroup(addr) {
+ return true
+ }
+ }
+
+ return false
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
+ r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
-
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -1216,7 +561,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
- enabled := n.mu.enabled
+ enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
n.mu.RUnlock()
@@ -1239,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.linkEP.LinkAddress()
+ local = n.LinkEndpoint.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -1274,17 +619,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
- if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
- // The source address is one of our own, so we never should have gotten a
- // packet like this unless handleLocal is false. Loopback also calls this
- // function even though the packets didn't come from the physical interface
- // so don't drop those.
- n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
- return
+ if n.stack.handleLocal && !n.IsLoopback() {
+ if r := n.getAddress(protocol, src); r != nil {
+ r.DecRef()
+
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
}
// Loopback traffic skips the prerouting chain.
- if !n.isLoopback() {
+ if !n.IsLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
@@ -1295,8 +644,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
+ if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
+ n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
return
}
@@ -1312,39 +661,39 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.ref.nic
- n.mu.RLock()
- ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
- n.mu.RUnlock()
- if ok {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
- r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
- r.Release()
- return
+ n := r.nic
+ if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
+ if n.isValidForOutgoing(addressEndpoint) {
+ r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
+ r.RemoteLinkAddress = remote
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ addressEndpoint.DecRef()
+ r.Release()
+ return
+ }
+
+ addressEndpoint.DecRef()
}
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
- // forwarder will release route.
- return
- }
+
+ // pkt may have set its header and may not have enough headroom for
+ // link-layer header for the other link to prepend. Here we create a new
+ // packet to forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- r.Release()
- return
}
- // The link-address resolution finished immediately.
- n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -1368,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.linkEP.AddHeader(local, remote, protocol, p)
+ n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- // pkt may have set its header and may not have enough headroom for link-layer
- // header for the other link to prepend. Here we create a new packet to
- // forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // WritePacket takes ownership of fwdPkt, calculate numBytes first.
- numBytes := fwdPkt.Size()
-
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- // TODO(gvisor.dev/issue/4365): Let the caller know that the transport
- // protocol is unrecognized.
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return TransportPacketHandled
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -1498,96 +822,18 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
}
-// ID returns the identifier of n.
+// ID implements NetworkInterface.
func (n *NIC) ID() tcpip.NICID {
return n.id
}
-// Name returns the name of n.
+// Name implements NetworkInterface.
func (n *NIC) Name() string {
return n.name
}
-// Stack returns the instance of the Stack that owns this NIC.
-func (n *NIC) Stack() *Stack {
- return n.stack
-}
-
-// LinkEndpoint returns the link endpoint of n.
-func (n *NIC) LinkEndpoint() LinkEndpoint {
- return n.linkEP
-}
-
-// isAddrTentative returns true if addr is tentative on n.
-//
-// Note that if addr is not associated with n, then this function will return
-// false. It will only return true if the address is associated with the NIC
-// AND it is tentative.
-func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return false
- }
-
- return ref.getKind() == permanentTentative
-}
-
-// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
-// duplicate on a link.
-//
-// dupTentativeAddrDetected will remove the tentative address if it exists. If
-// the address was generated via SLAAC, an attempt will be made to generate a
-// new address.
-func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadAddress
- }
-
- if ref.getKind() != permanentTentative {
- return tcpip.ErrInvalidEndpointState
- }
-
- // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
- // new address will be generated for it.
- if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
- return err
- }
-
- prefix := ref.addrWithPrefix().Subnet()
-
- switch ref.configType {
- case slaac:
- n.mu.ndp.regenerateSLAACAddr(prefix)
- case slaacTemp:
- // Do not reset the generation attempts counter for the prefix as the
- // temporary address is being regenerated in response to a DAD conflict.
- n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
- }
-
- return nil
-}
-
-// setNDPConfigs sets the NDP configurations for n.
-//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
-// use default values for the erroneous values.
-func (n *NIC) setNDPConfigs(c NDPConfigurations) {
- c.validate()
-
- n.mu.Lock()
- n.mu.ndp.configs = c
- n.mu.Unlock()
-}
-
-// NUDConfigs gets the NUD configurations for n.
-func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) {
+// nudConfigs gets the NUD configurations for n.
+func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
return NUDConfigurations{}, tcpip.ErrNotSupported
}
@@ -1607,49 +853,6 @@ func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
return nil
}
-// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
-func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.handleRA(ip, ra)
-}
-
-type networkEndpointKind int32
-
-const (
- // A permanentTentative endpoint is a permanent address that is not yet
- // considered to be fully bound to an interface in the traditional
- // sense. That is, the address is associated with a NIC, but packets
- // destined to the address MUST NOT be accepted and MUST be silently
- // dropped, and the address MUST NOT be used as a source address for
- // outgoing packets. For IPv6, addresses will be of this kind until
- // NDP's Duplicate Address Detection has resolved, or be deleted if
- // the process results in detecting a duplicate address.
- permanentTentative networkEndpointKind = iota
-
- // A permanent endpoint is created by adding a permanent address (vs. a
- // temporary one) to the NIC. Its reference count is biased by 1 to avoid
- // removal when no route holds a reference to it. It is removed by explicitly
- // removing the permanent address from the NIC.
- permanent
-
- // An expired permanent endpoint is a permanent endpoint that had its address
- // removed from the NIC, and it is waiting to be removed once no more routes
- // hold a reference to it. This is achieved by decreasing its reference count
- // by 1. If its address is re-added before the endpoint is removed, its type
- // changes back to permanent and its reference count increases by 1 again.
- permanentExpired
-
- // A temporary endpoint is created for spoofing outgoing packets, or when in
- // promiscuous mode and accepting incoming packets that don't match any
- // permanent endpoint. Its reference count is not biased by 1 and the
- // endpoint is removed immediately when no more route holds a reference to
- // it. A temporary endpoint can be promoted to permanent if its address
- // is added permanently.
- temporary
-)
-
func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -1680,153 +883,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
}
-type networkEndpointConfigType int32
-
-const (
- // A statically configured endpoint is an address that was added by
- // some user-specified action (adding an explicit address, joining a
- // multicast group).
- static networkEndpointConfigType = iota
-
- // A SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4862 section 5.5.3.
- slaac
-
- // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
- // not expected to be valid (or preferred) forever; hence the term temporary.
- slaacTemp
-)
-
-type referencedNetworkEndpoint struct {
- ep NetworkEndpoint
- addr tcpip.AddressWithPrefix
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
-
- // linkCache is set if link address resolution is enabled for this
- // protocol. Set to nil otherwise.
- linkCache LinkAddressCache
-
- // linkRes is set if link address resolution is enabled for this protocol.
- // Set to nil otherwise.
- linkRes LinkAddressResolver
-
- // refs is counting references held for this endpoint. When refs hits zero it
- // triggers the automatic removal of the endpoint from the NIC.
- refs int32
-
- // networkEndpointKind must only be accessed using {get,set}Kind().
- kind networkEndpointKind
-
- // configType is the method that was used to configure this endpoint.
- // This must never change except during endpoint creation and promotion to
- // permanent.
- configType networkEndpointConfigType
-
- // deprecated indicates whether or not the endpoint should be considered
- // deprecated. That is, when deprecated is true, other endpoints that are not
- // deprecated should be preferred.
- deprecated bool
-}
-
-func (r *referencedNetworkEndpoint) address() tcpip.Address {
- return r.addr.Address
-}
-
-func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
- return r.addr
-}
-
-func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
- return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
-}
-
-func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
- atomic.StoreInt32((*int32)(&r.kind), int32(kind))
-}
-
// isValidForOutgoing returns true if the endpoint can be used to send out a
// packet. It requires the endpoint to not be marked expired (i.e., its address)
// has been removed) unless the NIC is in spoofing mode, or temporary.
-func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- r.nic.mu.RLock()
- defer r.nic.mu.RUnlock()
-
- return r.isValidForOutgoingRLocked()
-}
-
-// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
-// r.nic.mu to be read locked.
-func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- if !r.nic.mu.enabled {
- return false
- }
-
- return r.isAssignedRLocked(r.nic.mu.spoofing)
-}
-
-// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
-//
-// r.nic.mu must be read locked.
-func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
- switch r.getKind() {
- case permanentTentative:
- return false
- case permanentExpired:
- return spoofingOrPromiscuous
- default:
- return true
- }
-}
-
-// expireLocked decrements the reference count and marks the permanent endpoint
-// as expired.
-func (r *referencedNetworkEndpoint) expireLocked() {
- r.setKind(permanentExpired)
- r.decRefLocked()
-}
-
-// decRef decrements the ref count and cleans up the endpoint once it reaches
-// zero.
-func (r *referencedNetworkEndpoint) decRef() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpoint(r)
- }
-}
-
-// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpointLocked(r)
- }
-}
-
-// incRef increments the ref count. It must only be called when the caller is
-// known to be holding a reference to the endpoint, otherwise tryIncRef should
-// be used.
-func (r *referencedNetworkEndpoint) incRef() {
- atomic.AddInt32(&r.refs, 1)
-}
-
-// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
-// not zero. That is, it will increment the count if the endpoint is still
-// alive, and do nothing if it has already been clean up.
-func (r *referencedNetworkEndpoint) tryIncRef() bool {
- for {
- v := atomic.LoadInt32(&r.refs)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
- return true
- }
- }
-}
-
-// stack returns the Stack instance that owns the underlying endpoint.
-func (r *referencedNetworkEndpoint) stack() *Stack {
- return r.nic.stack
+func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+ return n.Enabled() && ep.IsAssigned(spoofing)
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index bc9c9881a..97a96af62 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,96 +15,39 @@
package stack
import (
- "math"
"testing"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
-// A LinkEndpoint that throws away outgoing packets.
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
//
-// We use this instead of the channel endpoint as the channel package depends on
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
-type testLinkEndpoint struct {
- dispatcher NetworkDispatcher
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (e *testLinkEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*testLinkEndpoint) MTU() uint32 {
- return math.MaxUint16
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
- return CapabilityResolutionRequired
-}
+type testIPv6Endpoint struct {
+ AddressableEndpointState
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*testLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
+ nic NetworkInterface
+ protocol *testIPv6Protocol
-// LinkAddress returns the link address of this endpoint.
-func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
+ invalidatedRtr tcpip.Address
}
-// Wait implements LinkEndpoint.Wait.
-func (*testLinkEndpoint) Wait() {}
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) Enable() *tcpip.Error {
return nil
}
-// WritePackets implements LinkEndpoint.WritePackets.
-func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Our tests don't use this so we don't support it.
- return 0, tcpip.ErrNotSupported
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // Our tests don't use this so we don't support it.
- return tcpip.ErrNotSupported
-}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
- panic("not implemented")
+func (*testIPv6Endpoint) Enabled() bool {
+ return true
}
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- panic("not implemented")
-}
-
-var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
-
-// An IPv6 NetworkEndpoint that throws away outgoing packets.
-//
-// We use this instead of ipv6.endpoint because the ipv6 package depends on
-// the stack package which this test lives in, causing a cyclic dependency.
-type testIPv6Endpoint struct {
- nicID tcpip.NICID
- linkEP LinkEndpoint
- protocol *testIPv6Protocol
-}
+func (*testIPv6Endpoint) Disable() {}
// DefaultTTL implements NetworkEndpoint.DefaultTTL.
func (*testIPv6Endpoint) DefaultTTL() uint8 {
@@ -113,17 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.linkEP.MTU() - header.IPv6MinimumSize
-}
-
-// Capabilities implements NetworkEndpoint.Capabilities.
-func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+ return e.nic.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -144,23 +82,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
return tcpip.ErrNotSupported
}
-// NICID implements NetworkEndpoint.NICID.
-func (e *testIPv6Endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
// HandlePacket implements NetworkEndpoint.HandlePacket.
func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
}
// Close implements NetworkEndpoint.Close.
-func (*testIPv6Endpoint) Close() {}
+func (e *testIPv6Endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
+func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.invalidatedRtr = rtr
+}
+
var _ NetworkProtocol = (*testIPv6Protocol)(nil)
// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
@@ -192,12 +131,13 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
- return &testIPv6Endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
+ e := &testIPv6Endpoint{
+ nic: nic,
protocol: p,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
// SetOption implements NetworkProtocol.SetOption.
@@ -241,42 +181,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd
return "", false
}
-func newTestIPv6Protocol(*Stack) NetworkProtocol {
- return &testIPv6Protocol{}
-}
-
-// Test the race condition where a NIC is removed and an RS timer fires at the
-// same time.
-func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
- const (
- nicID = 1
-
- maxRtrSolicitations = 5
- )
-
- e := testLinkEndpoint{}
- s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{newTestIPv6Protocol},
- NDPConfigs: NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: minimumRtrSolicitationInterval,
- },
- })
-
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
- }
-
- s.mu.Lock()
- // Wait for the router solicitation timer to fire and block trying to obtain
- // the stack lock when doing link address resolution.
- time.Sleep(minimumRtrSolicitationInterval * 2)
- if err := s.removeNICLocked(nicID); err != nil {
- t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
- }
- s.mu.Unlock()
-}
-
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index a7d9d59fa..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
@@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
return newPk
}
+// Network returns the network header as a header.Network.
+//
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go
index 3eff141e6..f838eda8d 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
- nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-type forwardQueue struct {
+// packetsPendingLinkResolution is a queue of packets pending link resolution.
+//
+// Once link resolution completes successfully, the packets will be written.
+type packetsPendingLinkResolution struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]*pendingPacket
+ packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func newForwardQueue() *forwardQueue {
- return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+func (f *packetsPendingLinkResolution) init() {
+ f.Lock()
+ defer f.Unlock()
+ f.packets = make(map[<-chan struct{}][]pendingPacket)
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- shouldWait := false
-
+func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
f.Lock()
+ defer f.Unlock()
+
packets, ok := f.packets[ch]
- if !ok {
- shouldWait = true
- }
- for len(packets) == maxPendingPacketsPerResolution {
+ if len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
+ packets[0] = pendingPacket{}
packets = packets[1:]
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
+
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
- f.packets[ch] = append(packets, &pendingPacket{
- nic: n,
+
+ f.packets[ch] = append(packets, pendingPacket{
route: r,
- proto: protocol,
+ proto: proto,
pkt: pkt,
})
- f.Unlock()
- if !shouldWait {
+ if ok {
return
}
// Wait for the link-address resolution to complete.
- // Start a goroutine with a forwarding-cancel channel so that we can
- // limit the maximum number of goroutines running concurrently.
- cancel := f.newCancelChannel()
+ cancel := f.newCancelChannelLocked()
go func() {
cancelled := false
select {
@@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
}
f.Lock()
- packets := f.packets[ch]
+ packets, ok := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
+ if !ok {
+ panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
+ }
+
for _, p := range packets {
if cancelled {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
@@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
-func (f *forwardQueue) newCancelChannel() chan struct{} {
- f.Lock()
- defer f.Unlock()
-
+func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
+ f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 780a5ebde..defb9129b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -152,10 +154,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -206,6 +208,10 @@ const (
// transport layer and callers need not take any further action.
TransportPacketHandled TransportPacketDisposition = iota
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
// TransportPacketDestinationPortUnreachable indicates that there weren't any
// listeners interested in the packet and the transport protocol has no means
// to notify the sender.
@@ -259,9 +265,252 @@ type NetworkHeaderParams struct {
TOS uint8
}
+// GroupAddressableEndpoint is an endpoint that supports group addressing.
+//
+// An endpoint is considered to support group addressing when one or more
+// endpoints may associate themselves with the same identifier (group address).
+type GroupAddressableEndpoint interface {
+ // JoinGroup joins the spcified group.
+ //
+ // Returns true if the group was newly joined.
+ JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // LeaveGroup attempts to leave the specified group.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
+ LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // IsInGroup returns true if the endpoint is a member of the specified group.
+ IsInGroup(group tcpip.Address) bool
+}
+
+// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary
+// behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, they are ordered by recency.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+// AddressConfigType is the method used to add an address.
+type AddressConfigType int
+
+const (
+ // AddressConfigStatic is a statically configured address endpoint that was
+ // added by some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ AddressConfigStatic AddressConfigType = iota
+
+ // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862
+ // section 5.5.3.
+ AddressConfigSlaac
+
+ // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as
+ // per RFC 4941. Temporary SLAAC addresses are short-lived and are not
+ // to be valid (or preferred) forever; hence the term temporary.
+ AddressConfigSlaacTemp
+)
+
+// AssignableAddressEndpoint is a reference counted address endpoint that may be
+// assigned to a NetworkEndpoint.
+type AssignableAddressEndpoint interface {
+ // AddressWithPrefix returns the endpoint's address.
+ AddressWithPrefix() tcpip.AddressWithPrefix
+
+ // IsAssigned returns whether or not the endpoint is considered bound
+ // to its NetworkEndpoint.
+ IsAssigned(allowExpired bool) bool
+
+ // IncRef increments this endpoint's reference count.
+ //
+ // Returns true if it was successfully incremented. If it returns false, then
+ // the endpoint is considered expired and should no longer be used.
+ IncRef() bool
+
+ // DecRef decrements this endpoint's reference count.
+ DecRef()
+}
+
+// AddressEndpoint is an endpoint representing an address assigned to an
+// AddressableEndpoint.
+type AddressEndpoint interface {
+ AssignableAddressEndpoint
+
+ // GetKind returns the address kind for this endpoint.
+ GetKind() AddressKind
+
+ // SetKind sets the address kind for this endpoint.
+ SetKind(AddressKind)
+
+ // ConfigType returns the method used to add the address.
+ ConfigType() AddressConfigType
+
+ // Deprecated returns whether or not this endpoint is deprecated.
+ Deprecated() bool
+
+ // SetDeprecated sets this endpoint's deprecated status.
+ SetDeprecated(bool)
+}
+
+// AddressKind is the kind of of an address.
+//
+// See the values of AddressKind for more details.
+type AddressKind int
+
+const (
+ // PermanentTentative is a permanent address endpoint that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses are of this kind until NDP's
+ // Duplicate Address Detection (DAD) resolves. If DAD fails, the address
+ // is removed.
+ PermanentTentative AddressKind = iota
+
+ // Permanent is a permanent endpoint (vs. a temporary one) assigned to the
+ // NIC. Its reference count is biased by 1 to avoid removal when no route
+ // holds a reference to it. It is removed by explicitly removing the address
+ // from the NIC.
+ Permanent
+
+ // PermanentExpired is a permanent endpoint that had its address removed from
+ // the NIC, and it is waiting to be removed once no references to it are held.
+ //
+ // If the address is re-added before the endpoint is removed, its type
+ // changes back to Permanent.
+ PermanentExpired
+
+ // Temporary is an endpoint, created on a one-off basis to temporarily
+ // consider the NIC bound an an address that it is not explictiy bound to
+ // (such as a permanent address). Its reference count must not be biased by 1
+ // so that the address is removed immediately when references to it are no
+ // longer held.
+ //
+ // A temporary endpoint may be promoted to permanent if the address is added
+ // permanently.
+ Temporary
+)
+
+// IsPermanent returns true if the AddressKind represents a permanent address.
+func (k AddressKind) IsPermanent() bool {
+ switch k {
+ case Permanent, PermanentTentative:
+ return true
+ case Temporary, PermanentExpired:
+ return false
+ default:
+ panic(fmt.Sprintf("unrecognized address kind = %d", k))
+ }
+}
+
+// AddressableEndpoint is an endpoint that supports addressing.
+//
+// An endpoint is considered to support addressing when the endpoint may
+// associate itself with an identifier (address).
+type AddressableEndpoint interface {
+ // AddAndAcquirePermanentAddress adds the passed permanent address.
+ //
+ // Returns tcpip.ErrDuplicateAddress if the address exists.
+ //
+ // Acquires and returns the AddressEndpoint for the added address.
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error)
+
+ // RemovePermanentAddress removes the passed address if it is a permanent
+ // address.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed
+ // permanent address.
+ RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
+
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
+ // that is considered bound to the endpoint, optionally creating a temporary
+ // endpoint if requested and no existing address exists.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if the specified address is not local to this endpoint.
+ AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
+
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
+ //
+ // If allowExpired is true, expired addresses may be returned.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+
+ // PrimaryAddresses returns the primary addresses.
+ PrimaryAddresses() []tcpip.AddressWithPrefix
+
+ // PermanentAddresses returns all the permanent addresses.
+ PermanentAddresses() []tcpip.AddressWithPrefix
+}
+
+// NDPEndpoint is a network endpoint that supports NDP.
+type NDPEndpoint interface {
+ NetworkEndpoint
+
+ // InvalidateDefaultRouter invalidates a default router discovered through
+ // NDP.
+ InvalidateDefaultRouter(tcpip.Address)
+}
+
+// NetworkInterface is a network interface.
+type NetworkInterface interface {
+ NetworkLinkEndpoint
+
+ // ID returns the interface's ID.
+ ID() tcpip.NICID
+
+ // IsLoopback returns true if the interface is a loopback interface.
+ IsLoopback() bool
+
+ // Name returns the name of the interface.
+ //
+ // May return an empty string if the interface is not configured with a name.
+ Name() string
+
+ // Enabled returns true if the interface is enabled.
+ Enabled() bool
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
+ AddressableEndpoint
+
+ // Enable enables the endpoint.
+ //
+ // Must only be called when the stack is in a state that allows the endpoint
+ // to send and receive packets.
+ //
+ // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled.
+ Enable() *tcpip.Error
+
+ // Enabled returns true if the endpoint is enabled.
+ Enabled() bool
+
+ // Disable disables the endpoint.
+ Disable()
+
// DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
@@ -271,10 +520,6 @@ type NetworkEndpoint interface {
// minus the network endpoint max header length.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // underlying link-layer endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -295,9 +540,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // NICID returns the id of the NIC this endpoint belongs to.
- NICID() tcpip.NICID
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
@@ -312,6 +554,17 @@ type NetworkEndpoint interface {
NetworkProtocolNumber() tcpip.NetworkProtocolNumber
}
+// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
+type ForwardingNetworkProtocol interface {
+ NetworkProtocol
+
+ // Forwarding returns the forwarding configuration.
+ Forwarding() bool
+
+ // SetForwarding sets the forwarding configuration.
+ SetForwarding(bool)
+}
+
// NetworkProtocol is the interface that needs to be implemented by network
// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
type NetworkProtocol interface {
@@ -331,7 +584,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -409,22 +662,15 @@ const (
CapabilitySoftwareGSO
)
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
+// NetworkLinkEndpoint is a data-link layer that supports sending network
+// layer packets.
+type NetworkLinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -432,7 +678,7 @@ type LinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // link endpoint.
+ // endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -452,6 +698,19 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+}
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ NetworkLinkEndpoint
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
@@ -460,8 +719,8 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
- // Attach will be called with a nil dispatcher if the receiver's associated
- // NIC is being removed.
+ // Attach is called with a nil dispatcher when the endpoint's NIC is being
+ // removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 2cbbf0de8..25f80c1f8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -42,17 +42,27 @@ type Route struct {
// NetProto is the network-layer protocol.
NetProto tcpip.NetworkProtocolNumber
- // ref a reference to the network endpoint through which the route
- // starts.
- ref *referencedNetworkEndpoint
-
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+
+ // nic is the NIC the route goes through.
+ nic *NIC
+
+ // addressEndpoint is the local address this route is associated with.
+ addressEndpoint AssignableAddressEndpoint
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
}
// makeRoute initializes a new route. It takes ownership of the provided
-// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+// AssignableAddressEndpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
loop := PacketOut
if handleLocal && localAddr != "" && remoteAddr == localAddr {
loop = PacketLoop
@@ -62,29 +72,39 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- return Route{
+ r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: localLinkAddr,
+ LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
- ref: ref,
+ addressEndpoint: addressEndpoint,
+ nic: nic,
Loop: loop,
}
+
+ if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ r.linkRes = linkRes
+ r.linkCache = r.nic.stack
+ }
+ }
+
+ return r
}
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.ref.ep.NICID()
+ return r.nic.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.ref.ep.MaxHeaderLength()
+ return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.ref.nic.stack.Stats()
+ return r.nic.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -95,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.ref.ep.Capabilities()
+ return r.nic.LinkEndpoint.Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
-// ResolveWith immediately resolves a route with the specified remote link
-// address.
-func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
-}
-
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -138,8 +152,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if r.ref.nic.neigh != nil {
- entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker)
+ if neigh := r.nic.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -147,7 +161,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -162,12 +176,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if r.ref.nic.neigh != nil {
- r.ref.nic.neigh.removeWaker(nextAddr, waker)
+ if neigh := r.nic.neigh; neigh != nil {
+ neigh.removeWaker(nextAddr, waker)
return
}
- r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
@@ -175,104 +189,63 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.ref.nic.neigh != nil {
- return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == ""
+ if r.nic.neigh != nil {
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
}
- return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- err := r.ref.ep.WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- }
- return err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return n, err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
- // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Data.Size()
-
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.ref.ep.DefaultTTL()
+ return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.ref.ep.MTU()
-}
-
-// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
-// network endpoint.
-func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.ref.ep.NetworkProtocolNumber()
+ return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.ref != nil {
- r.ref.decRef()
- r.ref = nil
+ if r.addressEndpoint != nil {
+ r.addressEndpoint.DecRef()
+ r.addressEndpoint = nil
}
}
-// Clone Clone a route such that the original one can be released and the new
-// one will remain valid.
+// Clone clones the route.
func (r *Route) Clone() Route {
- if r.ref != nil {
- r.ref.incRef()
+ if r.addressEndpoint != nil {
+ _ = r.addressEndpoint.IncRef()
}
return *r
}
@@ -296,7 +269,7 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.ref.stack()
+ return r.nic.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -304,7 +277,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.ref.addrWithPrefix().Subnet()
+ subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
return subnet.IsBroadcast(addr)
}
@@ -330,7 +303,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
LocalLinkAddress: r.RemoteLinkAddress,
RemoteAddress: src,
RemoteLinkAddress: r.LocalLinkAddress,
- ref: r.ref,
Loop: r.Loop,
+ addressEndpoint: r.addressEndpoint,
+ nic: r.nic,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c22633f6b..3a07577c8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -363,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
-// NICNameFromID is a function that returns a stable name for the specified NIC,
-// even if different NIC IDs are used to refer to the same NIC in different
-// program runs. It is used when generating opaque interface identifiers (IIDs).
-// If the NIC was created with a name, it will be passed to NICNameFromID.
-//
-// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
-// generated for the same prefix on differnt NICs.
-type NICNameFromID func(tcpip.NICID, string) string
-
-// OpaqueInterfaceIdentifierOptions holds the options related to the generation
-// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
-type OpaqueInterfaceIdentifierOptions struct {
- // NICNameFromID is a function that returns a stable name for a specified NIC,
- // even if the NIC ID changes over time.
- //
- // Must be specified to generate the opaque IID.
- NICNameFromID NICNameFromID
-
- // SecretKey is a pseudo-random number used as the secret key when generating
- // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
- // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
- // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
- // change between program runs, unless explicitly changed.
- //
- // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
- // MUST NOT be modified after Stack is created.
- //
- // May be nil, but a nil value is highly discouraged to maintain
- // some level of randomness between nodes.
- SecretKey []byte
-}
-
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -402,13 +370,6 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
- // forwarding contains the whether packet forwarding is enabled or not for
- // different network protocols.
- forwarding struct {
- sync.RWMutex
- protocols map[tcpip.NetworkProtocolNumber]bool
- }
-
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
rawFactory RawFactory
@@ -461,9 +422,6 @@ type Stack struct {
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
- // ndpConfigs is the default NDP configurations used by interfaces.
- ndpConfigs NDPConfigurations
-
// nudConfigs is the default NUD configurations used by interfaces.
nudConfigs NUDConfigurations
@@ -471,15 +429,6 @@ type Stack struct {
// by the NIC's neighborCache instead of linkAddrCache.
useNeighborCache bool
- // autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
// nudDisp is the NUD event dispatcher that is used to send the netstack
// integrator NUD related events.
nudDisp NUDDispatcher
@@ -487,17 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
- // forwarder holds the packets that wait for their link-address resolutions
- // to complete, and forwards them when each resolution is done.
- forwarder *forwardQueue
+ // linkResQueue holds packets that are waiting for link resolution to
+ // complete.
+ linkResQueue packetsPendingLinkResolution
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -553,13 +494,6 @@ type Options struct {
// UniqueID is an optional generator of unique identifiers.
UniqueID UniqueID
- // NDPConfigs is the default NDP configurations used by interfaces.
- //
- // By default, NDPConfigs will have a zero value for its
- // DupAddrDetectTransmits field, implying that DAD will not be performed
- // before assigning an address to a NIC.
- NDPConfigs NDPConfigurations
-
// NUDConfigs is the default NUD configurations used by interfaces.
NUDConfigs NUDConfigurations
@@ -570,24 +504,6 @@ type Options struct {
// and ClearNeighbors.
UseNeighborCache bool
- // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs.
- //
- // Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address Detection
- // is enabled, an address will only be assigned if it successfully resolves.
- // If it fails, no further attempt will be made to auto-generate an IPv6
- // link-local address.
- //
- // The generated link-local address will follow RFC 4291 Appendix A
- // guidelines.
- AutoGenIPv6LinkLocal bool
-
- // NDPDisp is the NDP event dispatcher that an integrator can provide to
- // receive NDP related events.
- NDPDisp NDPDispatcher
-
// NUDDisp is the NUD event dispatcher that an integrator can provide to
// receive NUD related events.
NUDDisp NUDDispatcher
@@ -596,31 +512,12 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface
- // identifiers (IIDs) as outlined by RFC 7217.
- OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by rand.Read().
//
// RandSource must be thread-safe.
RandSource mathrand.Source
-
- // TempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- //
- // Temporary SLAAC adresses are short-lived addresses which are unpredictable
- // and random from the perspective of other nodes on the network. It is
- // recommended that the seed be a random byte buffer of at least
- // header.IIDSize bytes to make sure that temporary SLAAC addresses are
- // sufficiently random. It should follow minimum randomness requirements for
- // security as outlined by RFC 4086.
- //
- // Note: using a nil value, the same seed across netstack program runs, or a
- // seed that is too small would reduce randomness and increase predictability,
- // defeating the purpose of temporary SLAAC addresses.
- TempIIDSeed []byte
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -653,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -668,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(e.ID.LocalAddress) {
+ switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -680,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == e.NetProto:
- case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ case netProto == t.NetProto:
+ case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -723,36 +620,27 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
- // Make sure opts.NDPConfigs contains valid values only.
- opts.NDPConfigs.validate()
-
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- ndpConfigs: opts.NDPConfigs,
- nudConfigs: opts.NUDConfigs,
- useNeighborCache: opts.UseNeighborCache,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
- uniqueIDGenerator: opts.UniqueID,
- ndpDisp: opts.NDPDisp,
- nudDisp: opts.NUDDisp,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- forwarder: newForwardQueue(),
- randomGenerator: mathrand.New(randSrc),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ useNeighborCache: opts.UseNeighborCache,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -764,7 +652,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
- s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool)
+ s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -884,42 +772,37 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables packet forwarding between NICs.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
- s.forwarding.Lock()
- defer s.forwarding.Unlock()
-
- // If this stack does not support the protocol, do nothing.
- if _, ok := s.networkProtocols[protocol]; !ok {
- return
+// SetForwarding enables or disables packet forwarding between NICs for the
+// passed protocol.
+func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
}
- // If the forwarding value for this protocol hasn't changed then do
- // nothing.
- if forwarding := s.forwarding.protocols[protocol]; forwarding == enable {
- return
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return tcpip.ErrNotSupported
}
- s.forwarding.protocols[protocol] = enable
+ forwardingProtocol.SetForwarding(enable)
+ return nil
+}
- if protocol == header.IPv6ProtocolNumber {
- if enable {
- for _, nic := range s.nics {
- nic.becomeIPv6Router()
- }
- } else {
- for _, nic := range s.nics {
- nic.becomeIPv6Host()
- }
- }
+// Forwarding returns true if packet forwarding between NICs is enabled for the
+// passed protocol.
+func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return false
}
-}
-// Forwarding returns if packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- s.forwarding.RLock()
- defer s.forwarding.RUnlock()
- return s.forwarding.protocols[protocol]
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return false
+ }
+
+ return forwardingProtocol.Forwarding()
}
// SetRouteTable assigns the route table to be used by this stack. It
@@ -954,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewEndpoint(s, network, waiterQueue)
+ return t.proto.NewEndpoint(network, waiterQueue)
}
// NewRawEndpoint creates a new raw transport layer endpoint of the given
@@ -974,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewRawEndpoint(s, network, waiterQueue)
+ return t.proto.NewRawEndpoint(network, waiterQueue)
}
// NewPacketEndpoint creates a new packet endpoint listening for the given
@@ -1045,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetNICByName gets the NIC specified by name.
-func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+// GetLinkEndpointByName gets the link endpoint specified by name.
+func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic, true
+ return nic.LinkEndpoint
}
}
- return nil, false
+ return nil
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1081,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- return nic.disable()
+ nic.disable()
+ return nil
}
// CheckNIC checks if a NIC is usable.
@@ -1094,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
- return nic.enabled()
+ return nic.Enabled()
}
// RemoveNIC removes NIC and all related routes from the network stack.
@@ -1172,19 +1056,19 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.enabled(),
+ Running: nic.Enabled(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.isLoopback(),
+ Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.PrimaryAddresses(),
+ LinkAddress: nic.LinkEndpoint.LinkAddress(),
+ ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.linkEP.MTU(),
+ MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.linkEP.ARPHardwareType(),
+ ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
}
return nics
@@ -1243,7 +1127,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, peb)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1253,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ return nic.removeAddress(addr)
}
return tcpip.ErrUnknownNICID
@@ -1267,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
for id, nic := range s.nics {
- nics[id] = nic.AllAddresses()
+ nics[id] = nic.allPermanentAddresses()
}
return nics
}
@@ -1289,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
if len(localAddr) == 0 {
return nic.primaryEndpoint(netProto, remoteAddr)
}
@@ -1306,9 +1190,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
+ if nic, ok := s.nics[id]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
}
}
} else {
@@ -1316,20 +1200,20 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
- remoteAddr = ref.address()
+ remoteAddr = addressEndpoint.AddressWithPrefix().Address
}
- r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
+ r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
if len(route.Gateway) > 0 {
if needRoute {
r.NextHop = route.Gateway
}
- } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
@@ -1367,21 +1251,20 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
return 0
}
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref == nil {
+ addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
return 0
}
- ref.decRef()
+ addressEndpoint.DecRef()
return nic.id
}
// Go through all the NICs.
for _, nic := range s.nics {
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref != nil {
- ref.decRef()
+ if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
return nic.id
}
}
@@ -1440,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1656,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.linkEP.Wait()
+ n.LinkEndpoint.Wait()
}
}
@@ -1744,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.linkEP.LinkAddress(),
+ SrcAddr: nic.LinkEndpoint.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1753,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
return err
}
@@ -1770,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
return err
}
@@ -1850,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
- return nic.leaveGroup(multicastAddr)
+ return nic.leaveGroup(protocol, multicastAddr)
}
return tcpip.ErrUnknownNICID
}
@@ -1902,53 +1785,18 @@ func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
-// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
-//
-// Note that if addr is not associated with a NIC with id ID, then this
-// function will return false. It will only return true if the address is
-// associated with the NIC AND it is tentative.
-func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic, ok := s.nics[id]
- if !ok {
- return false, tcpip.ErrUnknownNICID
- }
-
- return nic.isAddrTentative(addr), nil
-}
-
-// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
-// tentative addr on it is a duplicate on a link.
-func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- nic, ok := s.nics[id]
- if !ok {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.dupTentativeAddrDetected(addr)
-}
-
-// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
-// with ID id to c.
-//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
-// use default values for the erroneous values.
-func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
+// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol
+// number installed on the specified NIC.
+func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
- nic, ok := s.nics[id]
+ nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return nil, tcpip.ErrUnknownNICID
}
- nic.setNDPConfigs(c)
- return nil
+ return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@@ -1961,7 +1809,7 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err
return NUDConfigurations{}, tcpip.ErrUnknownNICID
}
- return nic.NUDConfigs()
+ return nic.nudConfigs()
}
// SetNUDConfigurations sets the per-interface NUD configurations.
@@ -1980,22 +1828,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip
return nic.setNUDConfigs(c)
}
-// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
-// message that it needs to handle.
-func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- nic, ok := s.nics[id]
- if !ok {
- return tcpip.ErrUnknownNICID
- }
-
- nic.handleNDPRA(ip, ra)
-
- return nil
-}
-
// Seed returns a 32 bit value that can be used as a seed value for port
// picking, ISN generation etc.
//
@@ -2037,16 +1869,12 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
defer s.mu.RUnlock()
for _, nic := range s.nics {
- id := NetworkEndpointID{address}
-
- if ref, ok := nic.mu.endpoints[id]; ok {
- nic.mu.RLock()
- defer nic.mu.RUnlock()
-
- // An endpoint with this id exists, check if it can be
- // used and return it.
- return ref.ep, nil
+ addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ continue
}
+ addressEndpoint.DecRef()
+ return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}
@@ -2063,3 +1891,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
return nic.Name()
}
+
+// NewJob returns a new tcpip.Job using the stack's clock.
+func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index c205650fe..38994cca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"math"
- "net"
"sort"
"testing"
"time"
@@ -29,12 +28,12 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -68,18 +67,40 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
- nicID tcpip.NICID
+ stack.AddressableEndpointState
+
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
+ nic stack.NetworkInterface
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
}
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
+ return nil
}
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
+}
+
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
+}
+
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -111,17 +132,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+ return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -144,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -156,7 +173,9 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack
return tcpip.ErrNotSupported
}
-func (*fakeNetworkEndpoint) Close() {}
+func (f *fakeNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
@@ -165,6 +184,11 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -187,13 +211,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
- return &fakeNetworkEndpoint{
- nicID: nicID,
+func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &fakeNetworkEndpoint{
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: ep,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
@@ -231,6 +256,20 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -2063,7 +2102,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -2213,7 +2252,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName string
autoGen bool
linkAddr tcpip.LinkAddress
- iidOpts stack.OpaqueInterfaceIdentifierOptions
+ iidOpts ipv6.OpaqueInterfaceIdentifierOptions
shouldGen bool
expectedAddr tcpip.Address
}{
@@ -2229,7 +2268,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: false,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2274,7 +2313,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: true,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2286,7 +2325,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
{
name: "OIID Empty MAC and empty nicName",
autoGen: true,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:1],
},
@@ -2298,7 +2337,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test",
autoGen: true,
linkAddr: "\x01\x02\x03",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:2],
},
@@ -2310,7 +2349,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test2",
autoGen: true,
linkAddr: "\x01\x02\x03\x04\x05\x06",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:3],
},
@@ -2322,7 +2361,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test3",
autoGen: true,
linkAddr: "\x00\x00\x00\x00\x00\x00",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
},
shouldGen: true,
@@ -2336,10 +2375,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
+ })},
}
e := channel.New(0, 1280, test.linkAddr)
@@ -2411,15 +2451,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
tests := []struct {
name string
- opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions
}{
{
name: "IID From MAC",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{},
},
{
name: "Opaque IID",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
},
@@ -2430,9 +2470,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ })},
}
e := loopback.New()
@@ -2461,12 +2502,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
@@ -2813,14 +2855,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDispatcher{},
+ })},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDispatcher{},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -3059,12 +3102,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
dadC: make(chan ndpDADEvent),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(dadTransmits, 1280, linkAddr1)
@@ -3454,48 +3498,130 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
-func TestResolveWith(t *testing.T) {
+// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
+// associated address is removed should not cause a panic.
+func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
const (
- unspecifiedNICID = 0
- nicID = 1
+ nicID = 1
+ localAddr = tcpip.Address("\x01")
+ remoteAddr = tcpip.Address("\x02")
)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
+
ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- },
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err)
}
+ // Should not panic.
defer r.Release()
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
+ // Check that removing the same address fails.
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err)
}
+}
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
+func TestGetNetworkEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ factories := make([]stack.NetworkProtocolFactory, 0, len(tests))
+ for _, test := range tests {
+ factories = append(factories, test.protoFactory)
+ }
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: factories,
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep, err := s.GetNetworkEndpoint(nicID, test.protoNum)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err)
+ }
+
+ if got := ep.NetworkProtocolNumber(); got != test.protoNum {
+ t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum)
+ }
+ })
+ }
+}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is diabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 0774b5382..35e5b1a2e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return true
}
- // If the packet is a TCP packet with a non-unicast source or destination
- // address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
+ // If the packet is a TCP packet with a unspecified source or non-unicast
+ // destination address, then do nothing further and instruct the caller to do
+ // the same. The network layer handles address validation for specified source
+ // addresses.
+ if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -681,10 +683,6 @@ func isInboundMulticastOrBroadcast(r *Route) bool {
return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isInboundUnicast(r *Route) bool {
- return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
-}
-
-func isOutboundUnicast(r *Route) bool {
- return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
+func isSpecified(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8aae60740..62ab6d92f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -39,7 +39,7 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
- stack *stack.Stack
+
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
@@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Abort() {
@@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
+ r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}
@@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
+ if err := f.proto.stack.RegisterTransportEndpoint(
a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
@@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- stack: f.stack,
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct {
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
+ stack *stack.Stack
+
packetCount int
controlCount int
opts fakeTransportProtocolOptions
@@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -326,8 +327,8 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
return ok
}
-func fakeTransFactory(*stack.Stack) stack.TransportProtocol {
- return &fakeTransportProtocol{}
+func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
+ return &fakeTransportProtocol{stack: s}
}
func TestTransportReceive(t *testing.T) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 464608dee..c42bb0991 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -237,6 +237,14 @@ type Timer interface {
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string
+// WithPrefix returns the address with a prefix that represents a point subnet.
+func (a Address) WithPrefix() AddressWithPrefix {
+ return AddressWithPrefix{
+ Address: a,
+ PrefixLen: len(a) * 8,
+ }
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -1614,9 +1622,6 @@ type UDPStats struct {
// ChecksumErrors is the number of datagrams dropped due to bad checksums.
ChecksumErrors *StatCounter
-
- // InvalidSourceAddress is the number of invalid sourced datagrams dropped.
- InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 06c7a3cd3..a4f141253 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -6,6 +6,8 @@ go_test(
name = "integration_test",
size = "small",
srcs = [
+ "forward_test.go",
+ "link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
],
@@ -15,6 +17,8 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
new file mode 100644
index 000000000..ffd38ee1a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -0,0 +1,378 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestForwarding(t *testing.T) {
+ const (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+
+ listenPort = 8080
+ )
+
+ host1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 host2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack)
+ defer epsAndAddrs.serverEP.Close()
+ defer epsAndAddrs.clientEP.Close()
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) {
+ t.Helper()
+
+ dataPayload := tcpip.SlicePayload(data)
+ wOpts := tcpip.WriteOptions{To: to}
+ n, ch, err := ep.Write(dataPayload, wOpts)
+ if err == tcpip.ErrNoLinkAddress {
+ // Wait for link resolution to complete.
+ <-ch
+
+ n, _, err = ep.Write(dataPayload, wOpts)
+ } else if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+
+ if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+ }
+
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data, &serverAddr)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress {
+ t.Helper()
+
+ // Wait for the endpoint to be readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+
+ if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != expectedFrom {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ return addr
+ }
+
+ addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
+ // Unspecify the NIC since NIC IDs are meaningless across stacks.
+ addr.NIC = 0
+
+ data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ write(epsAndAddrs.serverEP, data, &addr)
+ addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr)
+ if addr.Port != listenPort {
+ t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
new file mode 100644
index 000000000..bf3a6f6ee
--- /dev/null
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+)
+
+// TestPing tests that two hosts can ping eachother when link resolution is
+// enabled.
+func TestPing(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ }{
+ {
+ name: "IPv4 Ping",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ {
+ name: "IPv6 Ping",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, waiterCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ // The first write should trigger link resolution.
+ icmpBuf := test.icmpBuf(t)
+ wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
+ if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress)
+ } else {
+ // Wait for link resolution to complete.
+ <-ch
+ }
+ if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(icmpBuf)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+
+ // Wait for the endpoint to be readable.
+ <-waiterCH
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.remoteAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index f35dcc084..e8caf09ba 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -16,6 +16,7 @@ package integration_test
import (
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -29,6 +30,69 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
+
+type ndpDispatcher struct{}
+
+func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
+
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return true
+}
+
+func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}
+
+func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
+
+func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {}
+
+// TestInitialLoopbackAddresses tests that the loopback interface does not
+// auto-generate a link-local address when it is brought up.
+func TestInitialLoopbackAddresses(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDispatcher{},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
+ t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
+ return ""
+ },
+ },
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ nicsInfo := s.NICInfo()
+ if nicInfo, ok := nicsInfo[nicID]; !ok {
+ t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo)
+ } else if got := len(nicInfo.ProtocolAddresses); got != 0 {
+ t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses)
+ }
+}
+
// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
// itself bound to all addresses in the subnet of an assigned address.
func TestLoopbackAcceptAllInSubnet(t *testing.T) {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 72d86b5ab..4f2ca7f54 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
if src != expectedSrc {
t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7484f4ad9..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -37,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -57,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -130,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6891fd245..0aaef495d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
@@ -1219,12 +1219,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
- // Increase counter if after processing the segment we would potentially
- // advertise a zero window.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
-
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7ad894840..3bcd3923a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -248,6 +248,11 @@ type ReceiveErrors struct {
// ZeroRcvWindowState is the number of times we advertised
// a zero receive window when rcvList is full.
ZeroRcvWindowState tcpip.StatCounter
+
+ // WantZeroWindow is the number of times we wanted to advertise a
+ // zero receive window but couldn't because it would have caused
+ // the receive window's right edge to shrink.
+ WantZeroRcvWindow tcpip.StatCounter
}
// SendErrors collect segment send errors within the transport layer.
@@ -1162,7 +1167,7 @@ func (e *endpoint) cleanupLocked() {
// wndFromSpace returns the window that we can advertise based on the available
// receive buffer space.
func wndFromSpace(space int) int {
- return space / (1 << rcvAdvWndScale)
+ return space >> rcvAdvWndScale
}
// initialReceiveWindow returns the initial receive window to advertise in the
@@ -1518,6 +1523,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
+// selectWindowLocked returns the new window without checking for shrinking or scaling
+// applied.
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
+ maxWindow := wndFromSpace(e.rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvBufUsed
+
+ // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
+ // cases where we receive a lot of small segments the segment overhead is a
+ // lot higher and we can run out socket buffer space before we can fill the
+ // previous window we advertised. In cases where we receive MSS sized or close
+ // MSS sized segments we will probably run out of window space before we
+ // exhaust receive buffer.
+ newWnd := wndFromAvailable
+ if newWnd > wndFromUsedBytes {
+ newWnd = wndFromUsedBytes
+ }
+ if newWnd < 0 {
+ newWnd = 0
+ }
+ return seqnum.Size(newWnd)
+}
+
+// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+func (e *endpoint) selectWindow() (wnd seqnum.Size) {
+ e.rcvListMu.Lock()
+ wnd = e.selectWindowLocked()
+ e.rcvListMu.Unlock()
+ return wnd
+}
+
// windowCrossedACKThresholdLocked checks if the receive window to be announced
// would be under aMSS or under the window derived from half receive buffer,
// whichever smaller. This is useful as a receive side silly window syndrome
@@ -1534,7 +1571,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
+ newAvail := int(e.selectWindowLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -2099,7 +2136,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID)
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
e.UnlockUser()
if err != nil {
return err
@@ -3013,6 +3050,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
EndSequence: rc.endSequence,
FACK: rc.fack,
RTT: rc.rtt,
+ Reord: rc.reorderSeen,
}
return s
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 6a3c2c32b..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -133,6 +133,8 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
recovery tcpip.TCPRecovery
@@ -159,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -505,8 +507,9 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a TCP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
p := protocol{
+ stack: s,
sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index d969ca23a..d312b1b8b 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -29,26 +29,36 @@ import (
//
// +stateify savable
type rackControl struct {
- // xmitTime is the latest transmission timestamp of rackControl.seg.
- xmitTime time.Time `state:".(unixTime)"`
-
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
+ // dsack indicates if the connection has seen a DSACK.
+ dsack bool
+
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
+ // minRTT is the estimated minimum RTT of the connection.
+ minRTT time.Duration
+
// rtt is the RTT of the most recently delivered packet on the
// connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
rtt time.Duration
+
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
}
-// Update will update the RACK related fields when an ACK has been received.
+// update will update the RACK related fields when an ACK has been received.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
-func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
rtt := time.Now().Sub(seg.xmitTime)
// If the ACK is for a retransmitted packet, do not update if it is a
@@ -65,12 +75,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
return
}
}
- if rtt < srtt {
+ if rtt < rc.minRTT {
return
}
}
rc.rtt = rtt
+
+ // The sender can either track a simple global minimum of all RTT
+ // measurements from the connection, or a windowed min-filtered value
+ // of recent RTT measurements. This implementation keeps track of the
+ // simple global minimum of all RTTs for the connection.
+ if rtt < rc.minRTT || rc.minRTT == 0 {
+ rc.minRTT = rtt
+ }
+
// Update rc.xmitTime and rc.endSequence to the transmit time and
// ending sequence number of the packet which has been acknowledged
// most recently.
@@ -80,3 +99,26 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
rc.endSequence = endSeq
}
}
+
+// detectReorder detects if packet reordering has been observed.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 3: Detect data segment reordering.
+// To detect reordering, the sender looks for original data segments being
+// delivered out of order. To detect such cases, the sender tracks the
+// highest sequence selectively or cumulatively acknowledged in the RACK.fack
+// variable. The name "fack" stands for the most "Forward ACK" (this term is
+// adopted from [FACK]). If a never retransmitted segment that's below
+// RACK.fack is (selectively or cumulatively) acknowledged, it has been
+// delivered out of order. The sender sets RACK.reord to TRUE if such segment
+// is identified.
+func (rc *rackControl) detectReorder(seg *segment) {
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.fack.LessThan(endSeq) {
+ rc.fack = endSeq
+ return
+ }
+
+ if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
+ rc.reorderSeen = true
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 4aafb4d22..8e0b7c843 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -43,6 +43,9 @@ type receiver struct {
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
+ // rcvWUP is the rcvNxt value at the last window update sent.
+ rcvWUP seqnum.Value
+
rcvWndScale uint8
closed bool
@@ -64,6 +67,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
+ rcvWUP: irs + 1,
rcvWndScale: rcvWndScale,
lastRcvdAckTime: time.Now(),
}
@@ -84,27 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
}
+// currentWindow returns the available space in the window that was advertised
+// last to our peer.
+func (r *receiver) currentWindow() (curWnd seqnum.Size) {
+ endOfWnd := r.rcvWUP.Add(r.rcvWnd)
+ if endOfWnd.LessThan(r.rcvNxt) {
+ // return 0 if r.rcvNxt is past the end of the previously advertised window.
+ // This can happen because we accept a large segment completely even if
+ // accepting it causes it to partially exceed the advertised window.
+ return 0
+ }
+ return r.rcvNxt.Size(endOfWnd)
+}
+
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- avail := wndFromSpace(r.ep.receiveBufferAvailable())
- acc := r.rcvNxt.Add(seqnum.Size(avail))
- newWnd := r.rcvNxt.Size(acc)
- curWnd := r.rcvNxt.Size(r.rcvAcc)
-
+ newWnd := r.ep.selectWindow()
+ curWnd := r.currentWindow()
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
// would end up dropping bytes that might already be in flight.
- if newWnd > curWnd {
- r.rcvAcc = r.rcvNxt.Add(newWnd)
+ // ==================================================== sequence space.
+ // ^ ^ ^ ^
+ // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // <=====curWnd ===>
+ // <========= newWnd > curWnd ========= >
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ // If the new window moves the right edge, then update rcvAcc.
+ r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
+ if newWnd == 0 {
+ // newWnd is zero but we can't advertise a zero as it would cause window
+ // to shrink so just increment a metric to record this event.
+ r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment()
+ }
newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
- return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
+ r.rcvWUP = r.rcvNxt
+ scaledWnd := r.rcvWnd >> r.rcvWndScale
+ if scaledWnd == 0 {
+ // Increment a metric if we are advertising an actual zero window.
+ r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+ return r.rcvNxt, scaledWnd
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 13acaf753..1f9c5cf50 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -71,6 +71,9 @@ type segment struct {
// xmitTime is the last transmit time of this segment.
xmitTime time.Time `state:".(unixTime)"`
xmitCount uint32
+
+ // acked indicates if the segment has already been SACKed.
+ acked bool
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index c55589c45..6fa8d63cd 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "sort"
"sync/atomic"
"time"
@@ -263,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
+ rc: rackControl{
+ fack: iss,
+ },
gso: ep.gso != nil,
}
@@ -1274,6 +1278,39 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
return true
}
+// Iterate the writeList and update RACK for each segment which is newly acked
+// either cumulatively or selectively. Loop through the segments which are
+// sacked, and update the RACK related variables and check for reordering.
+//
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// steps 2 and 3.
+func (s *sender) walkSACK(rcvdSeg *segment) {
+ // Sort the SACK blocks. The first block is the most recent unacked
+ // block. The following blocks can be in arbitrary order.
+ sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sort.Slice(sackBlocks, func(i, j int) bool {
+ return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
+ })
+
+ seg := s.writeList.Front()
+ for _, sb := range sackBlocks {
+ // This check excludes DSACK blocks.
+ if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
+ continue
+ }
+
+ for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
+ if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
+ seg.acked = true
+ }
+ seg = seg.Next()
+ }
+ }
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1308,6 +1345,21 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
rcvdSeg.hasNewSACKInfo = true
}
}
+
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08
+ // section-7.2
+ // * Step 2: Update RACK stats.
+ // If the ACK is not ignored as invalid, update the RACK.rtt
+ // to be the RTT sample calculated using this ACK, and
+ // continue. If this ACK or SACK was for the most recently
+ // sent packet, then record the RACK.xmit_ts timestamp and
+ // RACK.end_seq sequence implied by this ACK.
+ // * Step 3: Detect packet reordering.
+ // If the ACK selectively or cumulatively acknowledges an
+ // unacknowledged and also never retransmitted sequence below
+ // RACK.fack, then the corresponding packet has been
+ // reordered and RACK.reord is set to TRUE.
+ s.walkSACK(rcvdSeg)
s.SetPipe()
}
@@ -1365,9 +1417,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
- s.rtt.Lock()
- srtt := s.rtt.srtt
- s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1388,13 +1437,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted {
- s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ if s.ep.sackPermitted && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
}
s.writeList.Remove(seg)
- // if SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then Only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index e03f101e8..d3f92b48c 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -21,17 +21,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
+const (
+ maxPayload = 10
+ tsOptionSize = 12
+ maxTCPOptionSize = 40
+)
+
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
- const maxPayload = 10
- const tsOptionSize = 12
- const maxTCPOptionSize = 40
-
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
defer c.Cleanup()
@@ -49,7 +52,7 @@ func TestRACKUpdate(t *testing.T) {
}
if state.Sender.RACKState.RTT == 0 {
- t.Fatalf("RACK RTT failed to update when an ACK is received")
+ t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
}
})
setStackSACKPermitted(t, c, true)
@@ -69,6 +72,66 @@ func TestRACKUpdate(t *testing.T) {
bytesRead := 0
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- c.SendAck(790, bytesRead)
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
time.Sleep(200 * time.Millisecond)
}
+
+// TestRACKDetectReorder tests that RACK detects packet reordering.
+func TestRACKDetectReorder(t *testing.T) {
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ const ackNum = 2
+
+ var n int
+ ch := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ gotSeq := state.Sender.RACKState.FACK
+ wantSeq := state.Sender.SndNxt
+ // FACK should be updated to the highest ending sequence number of the
+ // segment acknowledged most recently.
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ n++
+ if n < ackNum {
+ if state.Sender.RACKState.Reord {
+ t.Fatalf("RACK reorder detected when there is no reordering")
+ }
+ return
+ }
+
+ if state.Sender.RACKState.Reord == false {
+ t.Fatalf("RACK reorder detection failed")
+ }
+ close(ch)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ data := buffer.NewView(ackNum * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < ackNum; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ start := c.IRS.Add(maxPayload + 1)
+ end := start.Add(maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-ch
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index dd810f594..a7149efd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5435,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
tests := []struct {
name string
@@ -6182,7 +6182,9 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
tsVal := uint32(rawEP.TSVal)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
@@ -6262,14 +6264,27 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
+
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- pkt := c.GetPacket()
- curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // Read loop above could generate an ACK if the window had dropped to
+ // zero and then read had opened it up.
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
+ }
+ lastACK = pkt
+ }
+ curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
// If thew new current window is close maxReceiveBufferSize then terminate
// the loop. This can happen before all iterations are done due to timing
// differences when running the test.
@@ -7326,7 +7341,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf * 2
+ remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
@@ -7341,7 +7356,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index faf51ef95..4d7847142 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -68,9 +68,9 @@ const (
// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- // testInitialSequenceNumber is the initial sequence number sent in packets that
+ // TestInitialSequenceNumber is the initial sequence number sent in packets that
// are sent in response to a SYN or in the initial SYN sent to the stack.
- testInitialSequenceNumber = 789
+ TestInitialSequenceNumber = 789
)
// StackAddrWithPrefix is StackAddr with its associated prefix length.
@@ -505,7 +505,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -532,7 +532,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -912,7 +912,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -1084,7 +1084,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
offset += paddingToAdd
// Send a SYN request.
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: StackPort,
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 086d0bdbc..d57ed5d79 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1397,15 +1397,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- // Never receive from a multicast address.
- if header.IsV4MulticastAddress(id.RemoteAddress) ||
- header.IsV6MulticastAddress(id.RemoteAddress) {
- e.stack.Stats().UDP.InvalidSourceAddress.Increment()
- e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- return
- }
-
if !verifyChecksum(r, hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index e6fc23258..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -45,6 +45,7 @@ const (
)
type protocol struct {
+ stack *stack.Stack
}
// Number returns the udp protocol number.
@@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a UDP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 0556ef879..b4604ba35 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -928,42 +928,6 @@ func TestReadFromMulticast(t *testing.T) {
}
}
-// TestReadFromMulticaststats checks that a discarded packet
-// that that was sent with multicast SOURCE address increments
-// the correct counters and that a regular packet does not.
-func TestReadFromMulticastStats(t *testing.T) {
- t.Helper()
- for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- c.injectPacket(flow, payload, false)
-
- var want uint64 = 0
- if flow.isReverseMulticast() {
- want = 1
- }
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
- t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
- t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -1466,6 +1430,28 @@ func TestNoChecksum(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1483,16 +1469,19 @@ func TestTTL(t *testing.T) {
if flow.isMulticast() {
wantTTL = multicastTTL
} else {
- var p stack.NetworkProtocol
+ var p stack.NetworkProtocolFactory
+ var n tcpip.NetworkProtocolNumber
if flow.isV4() {
- p = ipv4.NewProtocol(nil)
+ p = ipv4.NewProtocol
+ n = ipv4.ProtocolNumber
} else {
- p = ipv6.NewProtocol(nil)
+ p = ipv6.NewProtocol
+ n = ipv6.ProtocolNumber
}
- ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- }))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ })
+ ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil)
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1789,16 +1778,26 @@ func TestV4UnknownDestination(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4DstUnreachable),
checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ // We need to compare the included data part of the UDP packet that is in
+ // the ICMP packet with the matching original data.
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ // To work out the data size we need to simulate what the sender would
+ // have done. The wanted size is the total available minus the sum of
+ // the headers in the UDP AND ICMP packets, given that we know the test
+ // had only a minimal IP header but the ICMP sender will have allowed
+ // for a maximally sized packet header.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+
}
- // In case of large payloads the IP packet may be truncated. Update
+ // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // Add back the two headers within the payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram.Payload()), wantLen; got != want {
@@ -2024,7 +2023,8 @@ func TestPayloadModifiedV4(t *testing.T) {
payload := newPayload()
h := unicastV4.header4Tuple(incoming)
buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
@@ -2054,7 +2054,8 @@ func TestPayloadModifiedV6(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 06fb823f6..49ab87c58 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -270,7 +270,7 @@ func RandomID(prefix string) string {
// same name, sometimes between test runs the socket does not get cleaned up
// quickly enough, causing container creation to fail.
func RandomContainerID() string {
- return RandomID("test-container-")
+ return RandomID("test-container")
}
// Copy copies file from src to dst.
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 27279b409..9b1e7a085 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -21,7 +21,6 @@ import (
"io"
"strconv"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
@@ -184,51 +183,6 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) {
return n, err
}
-// CopyObjectOut copies a fixed-size value or slice of fixed-size values from
-// src to the memory mapped at addr in uio. It returns the number of bytes
-// copied.
-//
-// CopyObjectOut must use reflection to encode src; performance-sensitive
-// clients should do encoding manually and use uio.CopyOut directly.
-//
-// Preconditions: Same as IO.CopyOut.
-func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) {
- w := &IOReadWriter{
- Ctx: ctx,
- IO: uio,
- Addr: addr,
- Opts: opts,
- }
- // Allocate a byte slice the size of the object being marshaled. This
- // adds an extra reflection call, but avoids needing to grow the slice
- // during encoding, which can result in many heap-allocated slices.
- b := make([]byte, 0, binary.Size(src))
- return w.Write(binary.Marshal(b, ByteOrder, src))
-}
-
-// CopyObjectIn copies a fixed-size value or slice of fixed-size values from
-// the memory mapped at addr in uio to dst. It returns the number of bytes
-// copied.
-//
-// CopyObjectIn must use reflection to decode dst; performance-sensitive
-// clients should use uio.CopyIn directly and do decoding manually.
-//
-// Preconditions: Same as IO.CopyIn.
-func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) {
- r := &IOReadWriter{
- Ctx: ctx,
- IO: uio,
- Addr: addr,
- Opts: opts,
- }
- buf := make([]byte, binary.Size(dst))
- if _, err := io.ReadFull(r, buf); err != nil {
- return 0, err
- }
- binary.Unmarshal(buf, ByteOrder, dst)
- return int(r.Addr - addr), nil
-}
-
// CopyStringIn tuning parameters, defined outside that function for tests.
const (
copyStringIncrement = 64
diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
index bf3c5df2b..da60b0cc7 100644
--- a/pkg/usermem/usermem_test.go
+++ b/pkg/usermem/usermem_test.go
@@ -16,7 +16,6 @@ package usermem
import (
"bytes"
- "encoding/binary"
"fmt"
"reflect"
"strings"
@@ -174,23 +173,6 @@ type testStruct struct {
Uint64 uint64
}
-func TestCopyObject(t *testing.T) {
- wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8}
- wantN := binary.Size(wantObj)
- b := &BytesIO{make([]byte, wantN)}
- ctx := newContext()
- if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil {
- t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- var gotObj testStruct
- if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil {
- t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if gotObj != wantObj {
- t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj)
- }
-}
-
func TestCopyStringInShort(t *testing.T) {
// Tests for string length <= copyStringIncrement.
want := strings.Repeat("A", copyStringIncrement-2)