Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/BUILD | 1
-rw-r--r--  pkg/abi/linux/ioctl.go | 20
-rw-r--r--  pkg/abi/linux/membarrier.go | 34
-rw-r--r--  pkg/abi/linux/netfilter_ipv6.go | 13
-rw-r--r--  pkg/abi/linux/seccomp.go | 19
-rw-r--r--  pkg/abi/linux/signalfd.go | 4
-rw-r--r--  pkg/merkletree/merkletree.go | 256
-rw-r--r--  pkg/merkletree/merkletree_test.go | 209
-rw-r--r--  pkg/seccomp/BUILD | 2
-rw-r--r--  pkg/seccomp/seccomp_test.go | 246
-rw-r--r--  pkg/sentry/arch/arch_aarch64.go | 2
-rw-r--r--  pkg/sentry/arch/registers.proto | 1
-rw-r--r--  pkg/sentry/control/proc.go | 4
-rw-r--r--  pkg/sentry/devices/tundev/BUILD | 1
-rw-r--r--  pkg/sentry/devices/tundev/tundev.go | 14
-rw-r--r--  pkg/sentry/fs/dev/BUILD | 1
-rw-r--r--  pkg/sentry/fs/dev/net_tun.go | 52
-rw-r--r--  pkg/sentry/fs/fsutil/file_range_set.go | 21
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached.go | 17
-rw-r--r--  pkg/sentry/fs/tmpfs/inode_file.go | 2
-rw-r--r--  pkg/sentry/fs/user/path.go | 1
-rw-r--r--  pkg/sentry/fs/user/user.go | 1
-rw-r--r--  pkg/sentry/fsbridge/vfs.go | 2
-rw-r--r--  pkg/sentry/fsimpl/devpts/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts.go | 91
-rw-r--r--  pkg/sentry/fsimpl/devpts/master.go | 5
-rw-r--r--  pkg/sentry/fsimpl/devpts/replica.go | 3
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/devtmpfs.go | 14
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/ext/BUILD | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_file.go | 32
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_test.go | 46
-rw-r--r--  pkg/sentry/fsimpl/ext/directory.go | 3
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/BUILD | 3
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_32.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_64.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/block_group_test.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent.go | 3
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_new.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_old.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/dirent_test.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/disklayout.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/extent.go | 12
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/extent_test.go | 9
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode.go | 3
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_new.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_old.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/inode_test.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_32.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_64.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_old.go | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/superblock_test.go | 9
-rw-r--r--  pkg/sentry/fsimpl/ext/disklayout/test_utils.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/ext.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/ext_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_file.go | 5
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_test.go | 19
-rw-r--r--  pkg/sentry/fsimpl/ext/utils.go | 8
-rw-r--r--  pkg/sentry/fsimpl/fuse/fusefs.go | 66
-rw-r--r--  pkg/sentry/fsimpl/gofer/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 3
-rw-r--r--  pkg/sentry/fsimpl/gofer/regular_file.go | 12
-rw-r--r--  pkg/sentry/fsimpl/host/host.go | 6
-rw-r--r--  pkg/sentry/fsimpl/host/socket.go | 2
-rw-r--r--  pkg/sentry/fsimpl/kernfs/BUILD | 12
-rw-r--r--  pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 4
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go | 192
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 235
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs.go | 173
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs_test.go | 87
-rw-r--r--  pkg/sentry/fsimpl/kernfs/symlink.go | 7
-rw-r--r--  pkg/sentry/fsimpl/kernfs/synthetic_directory.go | 40
-rw-r--r--  pkg/sentry/fsimpl/overlay/overlay.go | 3
-rw-r--r--  pkg/sentry/fsimpl/pipefs/pipefs.go | 5
-rw-r--r--  pkg/sentry/fsimpl/proc/filesystem.go | 22
-rw-r--r--  pkg/sentry/fsimpl/proc/subtasks.go | 22
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go | 86
-rw-r--r--  pkg/sentry/fsimpl/proc/task_fds.go | 67
-rw-r--r--  pkg/sentry/fsimpl/proc/task_files.go | 34
-rw-r--r--  pkg/sentry/fsimpl/proc/task_net.go | 40
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks.go | 56
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_files.go | 22
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_sys.go | 118
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_test.go | 11
-rw-r--r--  pkg/sentry/fsimpl/signalfd/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/signalfd/signalfd.go | 14
-rw-r--r--  pkg/sentry/fsimpl/sockfs/sockfs.go | 5
-rw-r--r--  pkg/sentry/fsimpl/sys/kcov.go | 8
-rw-r--r--  pkg/sentry/fsimpl/sys/sys.go | 51
-rw-r--r--  pkg/sentry/fsimpl/testutil/testutil.go | 10
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/benchmark_test.go | 2
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/pipe_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/regular_file.go | 2
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs.go | 37
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/verity/BUILD | 21
-rw-r--r--  pkg/sentry/fsimpl/verity/filesystem.go | 216
-rw-r--r--  pkg/sentry/fsimpl/verity/verity.go | 238
-rw-r--r--  pkg/sentry/fsimpl/verity/verity_test.go | 491
-rw-r--r--  pkg/sentry/hostmm/BUILD | 3
-rw-r--r--  pkg/sentry/hostmm/membarrier.go | 90
-rw-r--r--  pkg/sentry/kernel/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/kcov.go | 40
-rw-r--r--  pkg/sentry/kernel/kernel.go | 46
-rw-r--r--  pkg/sentry/kernel/seccomp.go | 46
-rw-r--r--  pkg/sentry/kernel/signalfd/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/signalfd/signalfd.go | 14
-rw-r--r--  pkg/sentry/kernel/task.go | 4
-rw-r--r--  pkg/sentry/kernel/threads.go | 7
-rw-r--r--  pkg/sentry/kernel/vdso.go | 19
-rw-r--r--  pkg/sentry/loader/vdso.go | 6
-rw-r--r--  pkg/sentry/memmap/memmap.go | 4
-rw-r--r--  pkg/sentry/mm/mm.go | 14
-rw-r--r--  pkg/sentry/mm/pma.go | 24
-rw-r--r--  pkg/sentry/mm/syscalls.go | 25
-rw-r--r--  pkg/sentry/mm/vma.go | 10
-rw-r--r--  pkg/sentry/platform/BUILD | 1
-rw-r--r--  pkg/sentry/platform/kvm/BUILD | 12
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.s | 7
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go | 31
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_unsafe.go | 6
-rw-r--r--  pkg/sentry/platform/kvm/context.go | 5
-rw-r--r--  pkg/sentry/platform/kvm/filters_amd64.go | 13
-rw-r--r--  pkg/sentry/platform/kvm/filters_arm64.go | 11
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go | 13
-rw-r--r--  pkg/sentry/platform/kvm/kvm_arm64_test.go | 11
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const.go | 9
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const_arm64.go | 21
-rw-r--r--  pkg/sentry/platform/kvm/kvm_test.go | 29
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 52
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go | 186
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64_unsafe.go | 115
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64.go | 26
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 36
-rw-r--r--  pkg/sentry/platform/kvm/machine_unsafe.go | 26
-rw-r--r--  pkg/sentry/platform/platform.go | 51
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace.go | 1
-rw-r--r--  pkg/sentry/platform/ring0/defs_amd64.go | 38
-rw-r--r--  pkg/sentry/platform/ring0/entry_amd64.go | 7
-rw-r--r--  pkg/sentry/platform/ring0/entry_amd64.s | 204
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.s | 30
-rw-r--r--  pkg/sentry/platform/ring0/gen_offsets/BUILD | 5
-rw-r--r--  pkg/sentry/platform/ring0/kernel.go | 22
-rw-r--r--  pkg/sentry/platform/ring0/kernel_amd64.go | 64
-rw-r--r--  pkg/sentry/platform/ring0/kernel_arm64.go | 6
-rw-r--r--  pkg/sentry/platform/ring0/lib_amd64.go | 12
-rw-r--r--  pkg/sentry/platform/ring0/lib_amd64.s | 47
-rw-r--r--  pkg/sentry/platform/ring0/offsets_amd64.go | 11
-rw-r--r--  pkg/sentry/platform/ring0/offsets_arm64.go | 1
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go | 4
-rw-r--r--  pkg/sentry/platform/ring0/x86.go | 40
-rw-r--r--  pkg/sentry/socket/netfilter/netfilter.go | 4
-rw-r--r--  pkg/sentry/socket/netfilter/targets.go | 107
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go | 33
-rw-r--r--  pkg/sentry/syscalls/linux/BUILD | 1
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 8
-rw-r--r--  pkg/sentry/syscalls/linux/sys_file.go | 1
-rw-r--r--  pkg/sentry/syscalls/linux/sys_membarrier.go | 103
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sysinfo.go | 12
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/execve.go | 3
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/vfs2.go | 6
-rw-r--r--  pkg/sentry/vfs/BUILD | 1
-rw-r--r--  pkg/sentry/vfs/anonfs.go | 5
-rw-r--r--  pkg/sentry/vfs/file_description.go | 24
-rw-r--r--  pkg/sentry/vfs/filesystem_type.go | 3
-rw-r--r--  pkg/sentry/vfs/mount.go | 11
-rw-r--r--  pkg/sentry/vfs/vfs.go | 10
-rw-r--r--  pkg/tcpip/buffer/view.go | 18
-rw-r--r--  pkg/tcpip/checker/checker.go | 202
-rw-r--r--  pkg/tcpip/header/eth.go | 16
-rw-r--r--  pkg/tcpip/header/eth_test.go | 47
-rw-r--r--  pkg/tcpip/header/icmpv4.go | 9
-rw-r--r--  pkg/tcpip/header/icmpv6.go | 21
-rw-r--r--  pkg/tcpip/header/ipv4.go | 71
-rw-r--r--  pkg/tcpip/header/ipv6.go | 14
-rw-r--r--  pkg/tcpip/header/ipv6_extension_headers.go | 113
-rw-r--r--  pkg/tcpip/header/ipversion_test.go | 2
-rw-r--r--  pkg/tcpip/header/parse/parse.go | 2
-rw-r--r--  pkg/tcpip/link/pipe/BUILD | 15
-rw-r--r--  pkg/tcpip/link/pipe/pipe.go | 124
-rw-r--r--  pkg/tcpip/link/tun/device.go | 42
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 25
-rw-r--r--  pkg/tcpip/network/fragmentation/BUILD | 4
-rw-r--r--  pkg/tcpip/network/fragmentation/fragmentation.go | 81
-rw-r--r--  pkg/tcpip/network/fragmentation/fragmentation_test.go | 124
-rw-r--r--  pkg/tcpip/network/ip_test.go | 118
-rw-r--r--  pkg/tcpip/network/ipv4/BUILD | 3
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 146
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 234
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go | 577
-rw-r--r--  pkg/tcpip/network/ipv6/BUILD | 2
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 174
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 264
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 335
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6_test.go | 654
-rw-r--r--  pkg/tcpip/network/ipv6/ndp.go | 4
-rw-r--r--  pkg/tcpip/network/ipv6/ndp_test.go | 8
-rw-r--r--  pkg/tcpip/network/testutil/BUILD | 1
-rw-r--r--  pkg/tcpip/stack/BUILD | 5
-rw-r--r--  pkg/tcpip/stack/addressable_endpoint_state.go | 116
-rw-r--r--  pkg/tcpip/stack/addressable_endpoint_state_test.go | 11
-rw-r--r--  pkg/tcpip/stack/conntrack.go | 46
-rw-r--r--  pkg/tcpip/stack/forwarding_test.go (renamed from pkg/tcpip/stack/forwarder_test.go) | 12
-rw-r--r--  pkg/tcpip/stack/iptables.go | 4
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 70
-rw-r--r--  pkg/tcpip/stack/neighbor_cache.go | 15
-rw-r--r--  pkg/tcpip/stack/neighbor_cache_test.go | 66
-rw-r--r--  pkg/tcpip/stack/neighbor_entry.go | 4
-rw-r--r--  pkg/tcpip/stack/neighbor_entry_test.go | 5
-rw-r--r--  pkg/tcpip/stack/nic.go | 169
-rw-r--r--  pkg/tcpip/stack/nic_test.go | 10
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 15
-rw-r--r--  pkg/tcpip/stack/pending_packets.go (renamed from pkg/tcpip/stack/forwarder.go) | 60
-rw-r--r--  pkg/tcpip/stack/registration.go | 54
-rw-r--r--  pkg/tcpip/stack/route.go | 71
-rw-r--r--  pkg/tcpip/stack/stack.go | 48
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 125
-rw-r--r--  pkg/tcpip/tests/integration/BUILD | 4
-rw-r--r--  pkg/tcpip/tests/integration/forward_test.go | 378
-rw-r--r--  pkg/tcpip/tests/integration/link_resolution_test.go | 219
-rw-r--r--  pkg/tcpip/tests/integration/multicast_broadcast_test.go | 2
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 8
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 44
-rw-r--r--  pkg/tcpip/transport/tcp/rack.go | 54
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go | 54
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 3
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 62
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_rack_test.go | 75
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 20
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 12
-rw-r--r--  pkg/tcpip/transport/udp/udp_test.go | 30
-rw-r--r--  pkg/test/testutil/testutil.go | 2
235 files changed, 8081 insertions, 2767 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index cdcaa8c73..4a26e28de 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -38,6 +38,7 @@ go_library(
"ipc.go",
"limits.go",
"linux.go",
+ "membarrier.go",
"mm.go",
"netdevice.go",
"netfilter.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index dc9ac7e7c..7df02dd6d 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,9 +121,27 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
- FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_MEASURE_VERITY = 3221513862
)
+// DigestMetadata is a helper struct for VerityDigest.
+//
+// +marshal
+type DigestMetadata struct {
+ DigestAlgorithm uint16
+ DigestSize uint16
+}
+
+// SizeOfDigestMetadata is the size of struct DigestMetadata.
+const SizeOfDigestMetadata = 4
+
+// VerityDigest corresponds to struct fsverity_digest in uapi/linux/fsverity.h.
+type VerityDigest struct {
+ Metadata DigestMetadata
+ Digest []byte
+}
+
// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
func IOC(dir, typ, nr, size uint32) uint32 {
return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
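
Editor's note: the two decimal constants above are packed _IOC encodings. As a sanity check, a standalone sketch (the shift constants and the 128-byte fsverity_enable_arg size are taken from asm-generic/ioctl.h and uapi/linux/fsverity.h as assumptions, not from this diff) reproduces both values with the same formula as IOC:

package main

import "fmt"

// _IOC field shifts and direction bits from asm-generic/ioctl.h.
const (
	iocNRShift   = 0
	iocTypeShift = 8
	iocSizeShift = 16
	iocDirShift  = 30

	iocWrite = 1
	iocRead  = 2
)

// ioc mirrors the IOC helper shown in the diff above.
func ioc(dir, typ, nr, size uint32) uint32 {
	return dir<<iocDirShift | typ<<iocTypeShift | nr<<iocNRShift | size<<iocSizeShift
}

func main() {
	// FS_IOC_ENABLE_VERITY is _IOW('f', 133, struct fsverity_enable_arg),
	// a 128-byte write-only argument.
	fmt.Println(ioc(iocWrite, 'f', 133, 128)) // 1082156677
	// FS_IOC_MEASURE_VERITY is _IOWR('f', 134, struct fsverity_digest),
	// whose fixed header is the 4-byte DigestMetadata.
	fmt.Println(ioc(iocRead|iocWrite, 'f', 134, 4)) // 3221513862
}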
diff --git a/pkg/abi/linux/membarrier.go b/pkg/abi/linux/membarrier.go
new file mode 100644
index 000000000..4f6021a1d
--- /dev/null
+++ b/pkg/abi/linux/membarrier.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// membarrier(2) commands, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_QUERY = 0
+ MEMBARRIER_CMD_GLOBAL = (1 << 0)
+ MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1)
+ MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 5)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 6)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ = (1 << 7)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ = (1 << 8)
+)
+
+// membarrier(2) flags, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_FLAG_CPU = (1 << 0)
+)
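
Editor's note: every command except MEMBARRIER_CMD_QUERY (0) is a single bit, because querying returns a bitmask of the commands the kernel supports. A minimal sketch of how a caller might probe that mask (raw syscall via x/sys/unix; illustrative only, not how the sentry issues it):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// MEMBARRIER_CMD_QUERY returns the supported-command bitmask.
	mask, _, errno := unix.Syscall(unix.SYS_MEMBARRIER, 0 /* QUERY */, 0, 0)
	if errno != 0 {
		fmt.Println("membarrier unsupported:", errno)
		return
	}
	const cmdGlobal = 1 << 0 // MEMBARRIER_CMD_GLOBAL
	fmt.Println("GLOBAL supported:", mask&cmdGlobal != 0)
}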
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index a137940b6..6d31eb5e3 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -321,3 +321,16 @@ const (
// Enable all flags.
IP6T_INV_MASK = 0x7F
)
+
+// NFNATRange corresponds to struct nf_nat_range in
+// include/uapi/linux/netfilter/nf_nat.h.
+type NFNATRange struct {
+ Flags uint32
+ MinAddr Inet6Addr
+ MaxAddr Inet6Addr
+ MinProto uint16 // Network byte order.
+ MaxProto uint16 // Network byte order.
+}
+
+// SizeOfNFNATRange is the size of NFNATRange.
+const SizeOfNFNATRange = 40
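
Editor's note: as a sanity check on SizeOfNFNATRange, assuming Inet6Addr is a 16-byte array (as defined elsewhere in pkg/abi/linux), the fields pack with no padding. A local mirror of the layout confirms the 40 bytes:

package main

import (
	"fmt"
	"unsafe"
)

type inet6Addr [16]byte

// nfNATRange mirrors NFNATRange from the diff above.
type nfNATRange struct {
	Flags    uint32    // offset 0
	MinAddr  inet6Addr // offset 4
	MaxAddr  inet6Addr // offset 20
	MinProto uint16    // offset 36, network byte order
	MaxProto uint16    // offset 38, network byte order
}

func main() {
	fmt.Println(unsafe.Sizeof(nfNATRange{})) // 40
}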
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index b07cafe12..5be3f10f9 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -83,3 +83,22 @@ type SockFprog struct {
pad [6]byte
Filter *BPFInstruction
}
+
+// SeccompData is equivalent to struct seccomp_data, which contains the data
+// passed to seccomp-bpf filters.
+//
+// +marshal
+type SeccompData struct {
+ // Nr is the system call number.
+ Nr int32
+
+ // Arch is an AUDIT_ARCH_* value indicating the system call convention.
+ Arch uint32
+
+ // InstructionPointer is the value of the instruction pointer at the time
+ // of the system call.
+ InstructionPointer uint64
+
+ // Args contains the first 6 system call arguments.
+ Args [6]uint64
+}
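
Editor's note: field order here matters because seccomp-bpf filters address struct seccomp_data by byte offset rather than by name. A small sketch (a local mirror of the struct; offsets follow the Linux ABI) shows the offsets a filter's load instructions would use:

package main

import (
	"fmt"
	"unsafe"
)

// seccompData mirrors linux.SeccompData from the diff above.
type seccompData struct {
	Nr                 int32
	Arch               uint32
	InstructionPointer uint64
	Args               [6]uint64
}

func main() {
	var d seccompData
	fmt.Println(unsafe.Offsetof(d.Nr))                 // 0
	fmt.Println(unsafe.Offsetof(d.Arch))               // 4
	fmt.Println(unsafe.Offsetof(d.InstructionPointer)) // 8
	fmt.Println(unsafe.Offsetof(d.Args))               // 16 (Args[i] at 16 + 8*i)
	fmt.Println(unsafe.Sizeof(d))                      // 64
}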
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
index 85fad9956..468c6a387 100644
--- a/pkg/abi/linux/signalfd.go
+++ b/pkg/abi/linux/signalfd.go
@@ -23,6 +23,8 @@ const (
)
// SignalfdSiginfo is the siginfo encoding for signalfds.
+//
+// +marshal
type SignalfdSiginfo struct {
Signo uint32
Errno int32
@@ -41,5 +43,5 @@ type SignalfdSiginfo struct {
STime uint64
Addr uint64
AddrLSB uint16
- _ [48]uint8
+ _ [48]uint8 `marshal:"unaligned"`
}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 4b4f9bd52..d8227b8bd 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -41,7 +41,7 @@ type Layout struct {
blockSize int64
// digestSize is the size of a generated hash.
digestSize int64
- // levelOffset contains the offset of the begnning of each level in
+ // levelOffset contains the offset of the beginning of each level in
// bytes. The number of levels in the tree is the length of the slice.
// The leaf nodes (level 0) contain hashes of blocks of the input data.
// Each level N contains hashes of the blocks in level N-1. The highest
@@ -123,48 +123,73 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
return layout.levelOffset[level] + index*layout.blockSize
}
-// Generate constructs a Merkle tree for the contents of data. The output is
-// written to treeWriter. The treeReader should be able to read the tree after
-// it has been written. That is, treeWriter and treeReader should point to the
-// same underlying data but have separate cursors.
-// Generate will modify the cursor for data, but always restores it to its
-// original position upon exit. The cursor for tree is modified and not
-// restored.
-func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) {
- layout := InitLayout(dataSize, dataAndTreeInSameFile)
+// VerityDescriptor is a struct that is serialized and hashed to get a file's
+// root hash. It contains the root hash of the raw content and the file's
+// metadata.
+type VerityDescriptor struct {
+ Name string
+ Mode uint32
+ UID uint32
+ GID uint32
+ RootHash []byte
+}
- numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize
+func (d *VerityDescriptor) String() string {
+ return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+}
+
+// verify generates a hash from d, and compares it with expected.
+func (d *VerityDescriptor) verify(expected []byte) error {
+ h := sha256.Sum256([]byte(d.String()))
+ if !bytes.Equal(h[:], expected) {
+ return fmt.Errorf("unexpected root hash")
+ }
+ return nil
+}
+
+// GenerateParams contains the parameters used to generate a Merkle tree.
+type GenerateParams struct {
+ // File is a reader of the file to be hashed.
+ File io.ReaderAt
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // TreeReader is a reader for the Merkle tree.
+ TreeReader io.ReaderAt
+ // TreeWriter is a writer for the Merkle tree.
+ TreeWriter io.Writer
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
+// Generate constructs a Merkle tree for the contents of params.File. The
+// output is written to params.TreeWriter.
+//
+// Generate returns a hash of a VerityDescriptor, which contains the file
+// metadata and the hash of the file content.
+func Generate(params *GenerateParams) ([]byte, error) {
+ layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
+
+ numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
// If the data is in the same file as the tree, zero pad the last data
// block.
- bytesInLastBlock := dataSize % layout.blockSize
- if dataAndTreeInSameFile && bytesInLastBlock != 0 {
+ bytesInLastBlock := params.Size % layout.blockSize
+ if params.DataAndTreeInSameFile && bytesInLastBlock != 0 {
zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock)
- if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF {
- return nil, err
- }
- if _, err := treeWriter.Write(zeroBuf); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf); err != nil {
return nil, err
}
}
- // Store the current offset, so we can set it back once verification
- // finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
- if err != nil {
- return nil, err
- }
- defer data.Seek(origOffset, io.SeekStart)
-
- // Read from the beginning of both data and treeReader.
- if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF {
- return nil, err
- }
-
- if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF {
- return nil, err
- }
-
var root []byte
for level := 0; level < layout.numLevels(); level++ {
for i := int64(0); i < numBlocks; i++ {
@@ -176,11 +201,11 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
if level == 0 {
// Read data block from the target file since level 0 includes hashes
// of blocks in the input data.
- n, err = data.Read(buf)
+ n, err = params.File.ReadAt(buf, i*layout.blockSize)
} else {
// Read data block from the tree file since levels higher than 0 are
// hashing the lower level hashes.
- n, err = treeReader.Read(buf)
+ n, err = params.TreeReader.ReadAt(buf, layout.blockOffset(level-1, i))
}
// err is populated as long as the bytes read is smaller than the buffer
@@ -200,7 +225,7 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
}
// Write the generated hash to the end of the tree file.
- if _, err = treeWriter.Write(digest[:]); err != nil {
+ if _, err = params.TreeWriter.Write(digest[:]); err != nil {
return nil, err
}
}
@@ -208,46 +233,95 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
// remaining of the last block. But no need to do so for root.
if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 {
zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize)
- if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf[:]); err != nil {
return nil, err
}
}
numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock()
}
- return root, nil
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ ret := sha256.Sum256([]byte(descriptor.String()))
+ return ret[:], nil
+}
+
+// VerifyParams contains the params used to verify a portion of a file against
+// a Merkle tree.
+type VerifyParams struct {
+ // Out will be filled with verified data.
+ Out io.Writer
+ // File is a handle on the file to be verified.
+ File io.ReaderAt
+ // Tree is a handle on the Merkle tree used to verify File.
+ Tree io.ReaderAt
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // ReadOffset is the offset of the data range to be verified.
+ ReadOffset int64
+ // ReadSize is the size of the data range to be verified.
+ ReadSize int64
+ // Expected is a trusted hash for the file. It is compared with the
+ // calculated root hash to verify the content.
+ Expected []byte
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
+// verifyMetadata verifies the metadata by hashing a descriptor that contains
+// the metadata and comparing the generated hash with expected.
+//
+// For verifyMetadata, params.File is not needed. It only accesses params.Tree
+// for the raw root hash.
+func verifyMetadata(params *VerifyParams, layout *Layout) error {
+ root := make([]byte, layout.digestSize)
+ if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ return descriptor.verify(params.Expected)
}
// Verify verifies the content read from data with offset. The content is
// verified against tree. If content spans across multiple blocks, each block is
// verified. Verification fails if the hash of the data does not match the tree
-// at any level, or if the final root hash does not match expectedRoot.
-// Once the data is verified, it will be written using w.
-// Verify will modify the cursor for data, but always restores it to its
-// original position upon exit. The cursor for tree is modified and not
-// restored.
-func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) {
- if readSize <= 0 {
- return 0, fmt.Errorf("Unexpected read size: %d", readSize)
+// at any level, or if the final root hash does not match expected.
+// Once the data is verified, it will be written using params.Out.
+//
+// Verify checks both the target file content and metadata. If params.ReadSize
+// is 0, only the metadata is checked.
+func Verify(params *VerifyParams) (int64, error) {
+ if params.ReadSize < 0 {
+ return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
+ }
+ layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ if params.ReadSize == 0 {
+ return 0, verifyMetadata(params, &layout)
}
- layout := InitLayout(int64(dataSize), dataAndTreeInSameFile)
// Calculate the index of blocks that includes the target range in input
// data.
- firstDataBlock := readOffset / layout.blockSize
- lastDataBlock := (readOffset + readSize - 1) / layout.blockSize
-
- // Store the current offset, so we can set it back once verification
- // finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
- if err != nil {
- return 0, fmt.Errorf("Find current data offset failed: %v", err)
- }
- defer data.Seek(origOffset, io.SeekStart)
-
- // Move to the first block that contains target data.
- if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
- return 0, fmt.Errorf("Seek to datablock start failed: %v", err)
- }
+ firstDataBlock := params.ReadOffset / layout.blockSize
+ lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize
buf := make([]byte, layout.blockSize)
var readErr error
@@ -255,7 +329,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
for i := firstDataBlock; i <= lastDataBlock; i++ {
// Read a block that includes all or part of target range in
// input data.
- bytesRead, err := data.Read(buf)
+ bytesRead, err := params.File.ReadAt(buf, i*layout.blockSize)
readErr = err
// If at the end of input data and all previous blocks are
// verified, return the verified input data and EOF.
@@ -263,7 +337,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
break
}
if readErr != nil && readErr != io.EOF {
- return 0, fmt.Errorf("Read from data failed: %v", err)
+ return 0, fmt.Errorf("read from data failed: %w", err)
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
@@ -274,22 +348,29 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
buf[j] = 0
}
}
- if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ }
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil {
return 0, err
}
+
// startOff is the beginning of the read range within the
// current data block. Note that for all blocks other than the
// first, startOff should be 0.
startOff := int64(0)
if i == firstDataBlock {
- startOff = readOffset % layout.blockSize
+ startOff = params.ReadOffset % layout.blockSize
}
// endOff is the end of the read range within the current data
// block. Note that for all blocks other than the last, endOff
// should be the block size.
endOff := layout.blockSize
if i == lastDataBlock {
- endOff = (readOffset+readSize-1)%layout.blockSize + 1
+ endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1
}
// If the provided size exceeds the end of input data, we should
// only copy the parts in buf that's part of input data.
@@ -299,7 +380,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
if endOff > int64(bytesRead) {
endOff = int64(bytesRead)
}
- n, err := w.Write(buf[startOff:endOff])
+ n, err := params.Out.Write(buf[startOff:endOff])
if err != nil {
return total, err
}
@@ -313,9 +394,8 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
// original data. The block is verified through each level of the tree. It
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
-// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to
-// maintain the cursor if intended.
-func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error {
+// expected.
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -332,41 +412,27 @@ func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex
// Read a block in previous level that contains the
// hash we just generated, and generate a next level
// hash from it.
- if _, err := tree.Seek(layout.blockOffset(level-1, blockIndex), io.SeekStart); err != nil {
- return err
- }
- if _, err := tree.Read(treeBlock); err != nil {
+ if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil {
return err
}
digestArray := sha256.Sum256(treeBlock)
digest = digestArray[:]
}
- // Move to stored hash for the current block, read the digest
- // and store in expectedDigest.
- if _, err := tree.Seek(layout.digestOffset(level, blockIndex), io.SeekStart); err != nil {
- return err
- }
- if _, err := tree.Read(expectedDigest); err != nil {
+ // Read the digest for the current block and store in
+ // expectedDigest.
+ if _, err := tree.ReadAt(expectedDigest, layout.digestOffset(level, blockIndex)); err != nil {
return err
}
if !bytes.Equal(digest, expectedDigest) {
- return fmt.Errorf("Verification failed")
- }
-
- // If this is the root layer, no need to generate next level
- // hash.
- if level == layout.rootLevel() {
- break
+ return fmt.Errorf("verification failed")
}
blockIndex = blockIndex / layout.hashesPerBlock()
}
- // Verification for the tree succeeded. Now compare the root hash in the
- // tree with expectedRoot.
- if !bytes.Equal(digest[:], expectedRoot) {
- return fmt.Errorf("Verification failed")
- }
- return nil
+ // Verification for the tree succeeded. Now hash the descriptor with
+ // the root hash and compare it with expected.
+ descriptor.RootHash = digest
+ return descriptor.verify(expected)
}
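
Editor's note: taken together, this change replaces the seek-based io.ReadSeeker API with io.ReaderAt plus explicit parameter structs, and binds the file metadata into the returned hash via VerityDescriptor. A minimal, hypothetical usage sketch (the memTree helper and the one-page input are illustrative assumptions, not part of the diff):

package main

import (
	"bytes"
	"fmt"
	"io"

	"gvisor.dev/gvisor/pkg/merkletree"
)

// memTree is an in-memory tree file: Write appends, ReadAt reads back
// the same backing slice, so TreeReader and TreeWriter can share it.
type memTree struct{ b []byte }

func (m *memTree) Write(p []byte) (int, error) {
	m.b = append(m.b, p...)
	return len(p), nil
}

func (m *memTree) ReadAt(p []byte, off int64) (int, error) {
	if off >= int64(len(m.b)) {
		return 0, io.EOF
	}
	n := copy(p, m.b[off:])
	if n < len(p) {
		return n, io.EOF
	}
	return n, nil
}

func main() {
	data := bytes.Repeat([]byte{'a'}, 4096)
	tree := &memTree{}

	// Generate hashes the content into tree and returns the hash of the
	// VerityDescriptor (content root hash plus Name/Mode/UID/GID).
	hash, err := merkletree.Generate(&merkletree.GenerateParams{
		File:                  bytes.NewReader(data),
		Size:                  int64(len(data)),
		Name:                  "demo",
		Mode:                  0644,
		UID:                   0,
		GID:                   0,
		TreeReader:            tree,
		TreeWriter:            tree,
		DataAndTreeInSameFile: false,
	})
	if err != nil {
		panic(err)
	}

	// Verify re-reads the requested range against the tree; it fails if
	// the content or any piece of the metadata was modified.
	var out bytes.Buffer
	n, err := merkletree.Verify(&merkletree.VerifyParams{
		Out:                   &out,
		File:                  bytes.NewReader(data),
		Tree:                  tree,
		Size:                  int64(len(data)),
		Name:                  "demo",
		Mode:                  0644,
		UID:                   0,
		GID:                   0,
		ReadOffset:            0,
		ReadSize:              int64(len(data)),
		Expected:              hash,
		DataAndTreeInSameFile: false,
	})
	if err != nil && err != io.EOF {
		panic(err)
	}
	fmt.Println("verified bytes:", n) // 4096
}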
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index daaca759a..e1350ebda 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -84,6 +84,13 @@ func TestLayout(t *testing.T) {
}
}
+const (
+ defaultName = "merkle_test"
+ defaultMode = 0644
+ defaultUID = 0
+ defaultGID = 0
+)
+
// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike
// bytes.Buffer, it keeps the whole buffer during read so that it can be reused.
type bytesReadWriter struct {
@@ -99,58 +106,36 @@ func (brw *bytesReadWriter) Write(p []byte) (int, error) {
return len(p), nil
}
-func (brw *bytesReadWriter) Read(p []byte) (int, error) {
- if brw.readPos >= len(brw.bytes) {
- return 0, io.EOF
- }
- bytesRead := copy(p, brw.bytes[brw.readPos:])
- brw.readPos += bytesRead
- if bytesRead < len(p) {
+func (brw *bytesReadWriter) ReadAt(p []byte, off int64) (int, error) {
+ bytesRead := copy(p, brw.bytes[off:])
+ if bytesRead == 0 {
return bytesRead, io.EOF
}
return bytesRead, nil
}
-func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) {
- off := offset
- if whence == io.SeekCurrent {
- off += int64(brw.readPos)
- }
- if whence == io.SeekEnd {
- off += int64(len(brw.bytes))
- }
- if off < 0 {
- panic("seek with negative offset")
- }
- if off >= int64(len(brw.bytes)) {
- return 0, io.EOF
- }
- brw.readPos = int(off)
- return off, nil
-}
-
func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
data []byte
- expectedRoot []byte
+ expectedHash []byte
}{
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
},
{
data: []byte{'a'},
- expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90},
+ expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
},
}
@@ -158,22 +143,31 @@ func TestGenerate(t *testing.T) {
t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ params := GenerateParams{
+ Size: int64(len(tc.data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(tc.data)
- root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ params.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ params.File = &bytesReadWriter{
bytes: tc.data,
- }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ hash, err := Generate(&params)
if err != nil {
t.Fatalf("Got err: %v, want nil", err)
}
- if !bytes.Equal(root, tc.expectedRoot) {
- t.Errorf("Got root: %v, want %v", root, tc.expectedRoot)
+ if !bytes.Equal(hash, tc.expectedHash) {
+ t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash)
}
}
})
@@ -194,6 +188,10 @@ func TestVerify(t *testing.T) {
// modified byte falls in verification range, Verify should
// fail, otherwise Verify should still succeed.
modifyByte int64
+ modifyName bool
+ modifyMode bool
+ modifyUID bool
+ modifyGID bool
shouldSucceed bool
}{
// Verify range start outside the data range should fail.
@@ -222,12 +220,48 @@ func TestVerify(t *testing.T) {
modifyByte: 0,
shouldSucceed: false,
},
- // Invalid verify range (0 size) should fail.
+ // 0 verify size should only verify metadata.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ shouldSucceed: true,
+ },
+ // Modified name should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyName: true,
+ shouldSucceed: false,
+ },
+ // Modified mode should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyMode: true,
+ shouldSucceed: false,
+ },
+ // Modified UID should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyUID: true,
+ shouldSucceed: false,
+ },
+ // Modified GID should fail verification.
{
dataSize: usermem.PageSize,
verifyStart: 0,
verifySize: 0,
modifyByte: 0,
+ modifyGID: true,
shouldSucceed: false,
},
// The test cases below use a block-aligned verify range.
@@ -316,16 +350,25 @@ func TestVerify(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */)
+ }
}
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -333,8 +376,34 @@ func TestVerify(t *testing.T) {
// Flip a bit in data and checks Verify results.
var buf bytes.Buffer
data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
+ }
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
+ }
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
+ }
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
if tc.shouldSucceed {
- n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed when expected to succeed: %v", err)
}
@@ -348,7 +417,7 @@ func TestVerify(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
} else {
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&verifyParams); err == nil {
t.Errorf("Verification succeeded when expected to fail")
}
}
@@ -368,16 +437,26 @@ func TestVerifyRandom(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ hash, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -387,9 +466,24 @@ func TestVerifyRandom(t *testing.T) {
size := rand.Int63n(dataSize) + 1
var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: start,
+ ReadSize: size,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
// Checks that the random portion of data from the original data is
// verified successfully.
- n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed for correct data: %v", err)
}
@@ -406,13 +500,22 @@ func TestVerifyRandom(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
+ // Verify that modified metadata fails verification.
buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
+
// Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
randBytePos := rand.Int63n(size)
data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
- t.Errorf("Verification succeeded for modified data")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
}
}
}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index bdef7762c..e828894b0 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -49,7 +49,7 @@ go_test(
library = ":seccomp",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bpf",
+ "//pkg/usermem",
],
)
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 23f30678d..e1444d18b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -28,17 +28,10 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-type seccompData struct {
- nr uint32
- arch uint32
- instructionPointer uint64
- args [6]uint64
-}
-
// newVictim makes a victim binary.
func newVictim() (string, error) {
f, err := ioutil.TempFile("", "victim")
@@ -58,9 +51,14 @@ func newVictim() (string, error) {
return path, nil
}
-// asInput converts a seccompData to a bpf.Input.
-func (d *seccompData) asInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+// dataAsInput converts a linux.SeccompData to a bpf.Input.
+func dataAsInput(d *linux.SeccompData) bpf.Input {
+ buf := make([]byte, d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ Order: usermem.ByteOrder,
+ }
}
func TestBasic(t *testing.T) {
@@ -69,7 +67,7 @@ func TestBasic(t *testing.T) {
desc string
// data is the input data.
- data seccompData
+ data linux.SeccompData
// want is the expected return value of the BPF program.
want linux.BPFAction
@@ -95,12 +93,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "syscall allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "syscall disallowed",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -131,22 +129,22 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1a)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (1b)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "syscall 1 matched 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "no match",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -168,42 +166,42 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (3)",
- data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (5)",
- data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed (0)",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (2)",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (4)",
- data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (6)",
- data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (100)",
- data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -223,7 +221,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arch (123)",
- data: seccompData{nr: 1, arch: 123},
+ data: linux.SeccompData{Nr: 1, Arch: 123},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -243,7 +241,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "action trap",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -268,12 +266,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -300,12 +298,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "match first rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "match 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -331,28 +329,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "argument allowed (all match)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "argument disallowed (one mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "argument disallowed (multiple mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -379,28 +377,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (one equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (all equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -429,27 +427,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -474,27 +472,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -522,27 +520,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -567,32 +565,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (both greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -620,27 +618,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -665,32 +663,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -718,27 +716,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -764,32 +762,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -816,51 +814,51 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (low order mandatory bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000001
- args: [6]uint64{0x1},
+ Args: [6]uint64{0x1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (low order optional bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000101
- args: [6]uint64{0x5},
+ Args: [6]uint64{0x5},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (lowest order bit not set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000010
- args: [6]uint64{0x2},
+ Args: [6]uint64{0x2},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second lowest order bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000011
- args: [6]uint64{0x3},
+ Args: [6]uint64{0x3},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (8th bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000001 00000000
- args: [6]uint64{0x100},
+ Args: [6]uint64{0x100},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -885,12 +883,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -906,7 +904,7 @@ func TestBasic(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for _, spec := range test.specs {
- got, err := bpf.Exec(p, spec.data.asInput())
+ got, err := bpf.Exec(p, dataAsInput(&spec.data))
if err != nil {
t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err)
}
@@ -947,8 +945,8 @@ func TestRandom(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for i := uint32(0); i < 200; i++ {
- data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH}
- got, err := bpf.Exec(p, data.asInput())
+ data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH}
+ got, err := bpf.Exec(p, dataAsInput(&data))
if err != nil {
t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i)
continue
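
The helper dataAsInput replaces the old seccompData.asInput method in these tests. A minimal sketch of what such a helper can look like, assuming the marshalling methods used elsewhere in this diff (SizeBytes/MarshalBytes) and pkg/bpf's InputBytes type; the exact test helper may differ:

import (
	"encoding/binary"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/bpf"
)

// dataAsInput marshals a linux.SeccompData to its little-endian wire
// format and wraps it as a BPF interpreter input, mirroring the view a
// seccomp filter program gets of the syscall data.
func dataAsInput(data *linux.SeccompData) bpf.Input {
	buf := make([]byte, data.SizeBytes())
	data.MarshalBytes(buf)
	return bpf.InputBytes{Data: buf, Order: binary.LittleEndian}
}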
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 0f433ee79..fd73751e7 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -154,6 +154,7 @@ func (s State) Proto() *rpb.Registers {
Sp: s.Regs.Sp,
Pc: s.Regs.Pc,
Pstate: s.Regs.Pstate,
+ Tls: s.Regs.TPIDR_EL0,
}
return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
}
@@ -232,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
"Sp": uintptr(s.Regs.Sp),
"Pc": uintptr(s.Regs.Pc),
"Pstate": uintptr(s.Regs.Pstate),
+ "Tls": uintptr(s.Regs.TPIDR_EL0),
}, nil
}
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 60c027aab..2727ba08a 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -83,6 +83,7 @@ message ARM64Registers {
uint64 sp = 32;
uint64 pc = 33;
uint64 pstate = 34;
+ uint64 tls = 35;
}
message Registers {
oneof arch {
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 668f47802..1d88db12f 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -183,9 +183,9 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
if initArgs.MountNamespaceVFS2 == nil {
// Set initArgs so that 'ctx' returns the namespace.
//
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
+ // Add a reference to the namespace, which is transferred to the new process.
initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ initArgs.MountNamespaceVFS2.IncRef()
}
} else {
if initArgs.MountNamespace == nil {
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 71c59287c..14a8bf9cd 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 0b701a289..655ea549b 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -16,6 +16,8 @@
package tundev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -26,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -84,7 +87,16 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := fd.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := fd.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
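
For context on the code path being changed here (and again in net_tun.go below): SetIff is reached via the TUNSETIFF ioctl issued by a guest application. A minimal guest-side sketch using only standard Linux constants from <linux/if_tun.h>; this is ordinary TUN setup, not gVisor code:

package main

import (
	"os"
	"syscall"
	"unsafe"
)

const (
	tunsetiff = 0x400454ca // TUNSETIFF.
	iffTun    = 0x0001     // IFF_TUN: L3 TUN mode (vs. IFF_TAP for L2).
	iffNoPI   = 0x1000     // IFF_NO_PI: no packet information header.
)

// ifReq mirrors struct ifreq's flags union member (40 bytes on 64-bit).
type ifReq struct {
	Name  [16]byte
	Flags uint16
	_     [22]byte
}

func main() {
	f, err := os.OpenFile("/dev/net/tun", os.O_RDWR, 0)
	if err != nil {
		panic(err)
	}
	defer f.Close()
	req := ifReq{Flags: iffTun | iffNoPI}
	copy(req.Name[:], "tun0")
	// This ioctl is what lands in tunFD.Ioctl above; with this change, a
	// newly created interface also gets the ARP protocol address so the
	// netstack NIC can answer ARP requests.
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), tunsetiff, uintptr(unsafe.Pointer(&req))); errno != 0 {
		panic(errno)
	}
}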
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9379a4d7b..6b7b451b8 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netstack",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 5f8c9b5a2..19ffdec47 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -15,6 +15,8 @@
package dev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -25,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -60,7 +63,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod
}
// GetFile implements fs.InodeOperations.GetFile.
-func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
}
@@ -80,12 +83,12 @@ type netTunFileOperations struct {
var _ fs.FileOperations = (*netTunFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (fops *netTunFileOperations) Release(ctx context.Context) {
- fops.device.Release(ctx)
+func (n *netTunFileOperations) Release(ctx context.Context) {
+ n.device.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
request := args[1].Uint()
data := args[2].Pointer()
@@ -109,16 +112,25 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := n.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := n.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
- copy(req.IFName[:], fops.device.Name())
+ copy(req.IFName[:], n.device.Name())
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
// there is no sk_filter. See __tun_chr_ioctl() in drivers/net/tun.c.
- flags := fops.device.Flags() | linux.IFF_NOFILTER
+ flags := n.device.Flags() | linux.IFF_NOFILTER
usermem.ByteOrder.PutUint16(req.Data[:], flags)
_, err := req.CopyOut(t, data)
@@ -130,41 +142,41 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
}
// Write implements fs.FileOperations.Write.
-func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
data := make([]byte, src.NumBytes())
if _, err := src.CopyIn(ctx, data); err != nil {
return 0, err
}
- return fops.device.Write(data)
+ return n.device.Write(data)
}
// Read implements fs.FileOperations.Read.
-func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- data, err := fops.device.Read()
+func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := n.device.Read()
if err != nil {
return 0, err
}
- n, err := dst.CopyOut(ctx, data)
- if n > 0 && n < len(data) {
+ bytesCopied, err := dst.CopyOut(ctx, data)
+ if bytesCopied > 0 && bytesCopied < len(data) {
// Not an error for partial copying. Packet truncated.
err = nil
}
- return int64(n), err
+ return int64(bytesCopied), err
}
// Readiness implements waiter.Waitable.Readiness.
-func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return fops.device.Readiness(mask)
+func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return n.device.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- fops.device.EventRegister(e, mask)
+func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ n.device.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
- fops.device.EventUnregister(e)
+func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ n.device.EventUnregister(e)
}
// isNetTunSupported returns whether the /dev/net/tun device is supported for s.
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 9197aeb88..1dc409d38 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -84,7 +84,8 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRan
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
// mmap(2): bytes after EOF on the same page are zeroed; pages after EOF are
-// invalid.
+// invalid. fileSize is an upper bound on the file's size; bytes after fileSize
+// will be zeroed without calling readAt.
//
// Fill may read offsets outside of required, but will never read offsets
// outside of optional. It returns a non-nil error if any error occurs, even
@@ -94,7 +95,7 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRan
// * required.Length() > 0.
// * optional.IsSupersetOf(required).
// * required and optional must be page-aligned.
-func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, fileSize uint64, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
gap := frs.LowerBoundGap(required.Start)
for gap.Ok() && gap.Start() < required.End {
if gap.Range().Length() == 0 {
@@ -107,7 +108,21 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
fr, err := mf.AllocateAndFill(gr.Length(), kind, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
var done uint64
for !dsts.IsEmpty() {
- n, err := readAt(ctx, dsts, gr.Start+done)
+ n, err := func() (uint64, error) {
+ off := gr.Start + done
+ if off >= fileSize {
+ return 0, io.EOF
+ }
+ if off+dsts.NumBytes() > fileSize {
+ rd := fileSize - off
+ n, err := readAt(ctx, dsts.TakeFirst64(rd), off)
+ if n == rd && err == nil {
+ return n, io.EOF
+ }
+ return n, err
+ }
+ return readAt(ctx, dsts, off)
+ }()
done += n
dsts = dsts.DropFirst64(n)
if err != nil {
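
The closure added here clamps every read at fileSize: offsets at or past fileSize yield io.EOF without calling readAt, and a read straddling fileSize is truncated to the in-bounds prefix (the remaining bytes stay zeroed, since AllocateAndFill starts from zeroed pages). The same logic, restated as a standalone helper over io.ReaderAt (a simplified sketch, not the sentry code):

import "io"

// clampedRead reads into dst from r, treating fileSize as a hard EOF:
// offsets at or past fileSize return io.EOF immediately, and reads that
// straddle fileSize are truncated to the in-bounds prefix.
func clampedRead(r io.ReaderAt, dst []byte, off, fileSize uint64) (int, error) {
	if off >= fileSize {
		return 0, io.EOF
	}
	if off+uint64(len(dst)) > fileSize {
		dst = dst[:fileSize-off]
	}
	return r.ReadAt(dst, int64(off))
}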
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 9eb6f522e..82eda3e43 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/kernel/time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -444,7 +443,7 @@ func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs.
// time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchAccessTimeLocked(now ktime.Time) {
c.attr.AccessTime = now
c.dirtyAttr.AccessTime = true
}
@@ -461,7 +460,7 @@ func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx contex
// and status change times in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now ktime.Time) {
c.attr.ModificationTime = now
c.dirtyAttr.ModificationTime = true
c.attr.StatusChangeTime = now
@@ -480,7 +479,7 @@ func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) {
// in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now ktime.Time) {
c.attr.StatusChangeTime = now
c.dirtyAttr.StatusChangeTime = true
}
@@ -645,7 +644,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
End: fs.OffsetPageEnd(int64(gapMR.End)),
}
optMR := gap.Range()
- err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
+ err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), uint64(rw.c.attr.Size), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
mem.MarkEvictable(rw.c, pgalloc.EvictableRange{optMR.Start, optMR.End})
seg, gap = rw.c.cache.Find(uint64(rw.offset))
if !seg.Ok() {
@@ -672,9 +671,6 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
-
- default:
- break
}
}
unlock()
@@ -768,9 +764,6 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
-
- default:
- break
}
}
rw.maybeGrowFile()
@@ -877,7 +870,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option
}
mf := c.mfp.MemoryFile()
- cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
+ cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), uint64(c.attr.Size), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
var ts []memmap.Translation
var translatedEnd uint64
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index 1dc75291d..fc0498f17 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -613,7 +613,7 @@ func (f *fileInodeOperations) Translate(ctx context.Context, required, optional
}
mf := f.kernel.MemoryFile()
- cerr := f.data.Fill(ctx, required, optional, mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ cerr := f.data.Fill(ctx, required, optional, uint64(f.attr.Size), mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
// Newly-allocated pages are zeroed, so we don't need to do anything.
return dsts.NumBytes(), nil
})
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index 2f5a43b84..124bc95ed 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -121,6 +121,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
root := mns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
for _, p := range paths {
if !path.IsAbs(p) {
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
index 936fd3932..1f8684dc6 100644
--- a/pkg/sentry/fs/user/user.go
+++ b/pkg/sentry/fs/user/user.go
@@ -105,6 +105,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.
const defaultHome = "/"
root := mns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
creds := auth.CredentialsFromContext(ctx)
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 323506d33..be0900030 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// remainingTraversals is not configurable in VFS2; all callers use the
// default anyway.
func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
- vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ vfsObj := l.root.Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 48e13613a..84baaac66 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index 903135fae..d5c5aaa8c 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -37,27 +37,51 @@ const Name = "devpts"
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
-type FilesystemType struct{}
+type FilesystemType struct {
+ initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ initErr error
+
+ // fs backs all mounts of this FilesystemType. root is fs' root. fs and root
+ // are immutable.
+ fs *vfs.Filesystem
+ root *vfs.Dentry
+}
// Name implements vfs.FilesystemType.Name.
-func (FilesystemType) Name() string {
+func (*FilesystemType) Name() string {
return Name
}
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// No data allowed.
if opts.Data != "" {
return nil, nil, syserror.EINVAL
}
- fs, root, err := fstype.newFilesystem(vfsObj, creds)
- if err != nil {
- return nil, nil, err
+ fstype.initOnce.Do(func() {
+ fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ if err != nil {
+ fstype.initErr = err
+ return
+ }
+ fstype.fs = fs.VFSFilesystem()
+ fstype.root = root.VFSDentry()
+ })
+ if fstype.initErr != nil {
+ return nil, nil, fstype.initErr
+ }
+ fstype.fs.IncRef()
+ fstype.root.IncRef()
+ return fstype.fs, fstype.root, nil
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (fstype *FilesystemType) Release(ctx context.Context) {
+ if fstype.fs != nil {
+ fstype.root.DecRef(ctx)
+ fstype.fs.DecRef(ctx)
}
- return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil
}
// +stateify savable
@@ -69,7 +93,7 @@ type filesystem struct {
// newFilesystem creates a new devpts filesystem with root directory and ptmx
// master inode. It returns the filesystem and root Dentry.
-func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -87,7 +111,9 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
root.EnableLeakCheck()
- root.dentry.Init(root)
+
+ var rootD kernfs.Dentry
+ rootD.Init(&fs.Filesystem, root)
// Construct the pts master inode and dentry. Linux always uses inode
// id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
@@ -95,15 +121,14 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root: root,
}
master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
- master.dentry.Init(master)
// Add the master as a child of the root.
- links := root.OrderedChildren.Populate(&root.dentry, map[string]*kernfs.Dentry{
- "ptmx": &master.dentry,
+ links := root.OrderedChildren.Populate(map[string]kernfs.Inode{
+ "ptmx": master,
})
root.IncLinks(links)
- return fs, &root.dentry, nil
+ return fs, &rootD, nil
}
// Release implements vfs.FilesystemImpl.Release.
@@ -117,24 +142,19 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type rootInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This has no effect here since this inode can't be looked up and is always valid.
kernfs.OrderedChildren
rootInodeRefs
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// master is the master pty inode. Immutable.
master *masterInode
- // root is the root directory inode for this filesystem. Immutable.
- root *rootInode
-
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
@@ -173,21 +193,24 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
- replica.dentry.Init(replica)
i.replicas[idx] = replica
return t, nil
}
// masterClose is called when the master end of t is closed.
-func (i *rootInode) masterClose(t *Terminal) {
+func (i *rootInode) masterClose(ctx context.Context, t *Terminal) {
i.mu.Lock()
defer i.mu.Unlock()
// Sanity check that the replica with this index exists.
- if _, ok := i.replicas[t.n]; !ok {
+ ri, ok := i.replicas[t.n]
+ if !ok {
panic(fmt.Sprintf("pty with index %d does not exist", t.n))
}
+
+ // Drop the ref on replica inode taken during rootInode.allocateTerminal.
+ ri.DecRef(ctx)
delete(i.replicas, t.n)
}
@@ -203,16 +226,22 @@ func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.D
}
// Lookup implements kernfs.Inode.Lookup.
-func (i *rootInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
+ // Check if a static entry was looked up.
+ if d, err := i.OrderedChildren.Lookup(ctx, name); err == nil {
+ return d, nil
+ }
+
+ // Not a static entry.
idx, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
}
i.mu.Lock()
defer i.mu.Unlock()
- if si, ok := i.replicas[uint32(idx)]; ok {
- si.dentry.IncRef()
- return &si.dentry, nil
+ if ri, ok := i.replicas[uint32(idx)]; ok {
+ ri.IncRef() // This ref is passed to the dentry upon creation via Init.
+ return ri, nil
}
return nil, syserror.ENOENT
@@ -243,8 +272,8 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback,
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *rootInode) DecRef(context.Context) {
- i.rootInodeRefs.DecRef(i.Destroy)
+func (i *rootInode) DecRef(ctx context.Context) {
+ i.rootInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// +stateify savable
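
devpts now mirrors the single-instance pattern devtmpfs uses below: construct the filesystem once under a sync.Once, hand out fresh references on every later GetFilesystem, and drop the construction references in Release. The shape of that pattern, reduced to a toy refcounted type (a sketch, not the vfs API):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// instance stands in for a refcounted *vfs.Filesystem.
type instance struct{ refs int64 }

func (i *instance) IncRef() { atomic.AddInt64(&i.refs, 1) }
func (i *instance) DecRef() { atomic.AddInt64(&i.refs, -1) }

type fsType struct {
	initOnce sync.Once
	initErr  error
	fs       *instance
}

// get constructs the shared filesystem on first use; every successful
// call transfers one new reference to the caller.
func (t *fsType) get() (*instance, error) {
	t.initOnce.Do(func() {
		t.fs = &instance{refs: 1} // Construction ref, dropped in release.
	})
	if t.initErr != nil {
		return nil, t.initErr
	}
	t.fs.IncRef()
	return t.fs, nil
}

// release drops the construction reference taken in get.
func (t *fsType) release() {
	if t.fs != nil {
		t.fs.DecRef()
	}
}

func main() {
	var t fsType
	a, _ := t.get()
	b, _ := t.get()
	fmt.Println(a == b) // true: both "mounts" share one filesystem.
	a.DecRef()
	b.DecRef()
	t.release()
}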
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 69c2fe951..fda30fb93 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -42,9 +42,6 @@ type masterInode struct {
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// root is the devpts root inode.
root *rootInode
}
@@ -103,7 +100,7 @@ var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
func (mfd *masterFileDescription) Release(ctx context.Context) {
- mfd.inode.root.masterClose(mfd.t)
+ mfd.inode.root.masterClose(ctx, mfd.t)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
index 6515c5536..70c68cf0a 100644
--- a/pkg/sentry/fsimpl/devpts/replica.go
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -41,9 +41,6 @@ type replicaInode struct {
locks vfs.FileLocks
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
// root is the devpts root inode.
root *rootInode
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index 6d1753080..e6fe0fc0d 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -71,6 +71,15 @@ func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virtua
return fst.fs, fst.root, nil
}
+// Release implements vfs.FilesystemType.Release.
+func (fst *FilesystemType) Release(ctx context.Context) {
+ if fst.fs != nil {
+ // Release the original reference obtained when creating the filesystem.
+ fst.root.DecRef(ctx)
+ fst.fs.DecRef(ctx)
+ }
+}
+
// Accessor allows devices to create device special files in devtmpfs.
type Accessor struct {
vfsObj *vfs.VirtualFilesystem
@@ -86,10 +95,13 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth
if err != nil {
return nil, err
}
+ // Pass a reference on root to the Accessor.
+ root := mntns.Root()
+ root.IncRef()
return &Accessor{
vfsObj: vfsObj,
mntns: mntns,
- root: mntns.Root(),
+ root: root,
creds: creds,
}, nil
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index 3a38b8bb4..e058eda7a 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -53,6 +53,7 @@ func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.Virtu
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
root := mntns.Root()
+ root.IncRef()
devpop := vfs.PathOperation{
Root: root,
Start: root,
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index abc610ef3..7b1eec3da 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/fd",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
@@ -86,9 +88,9 @@ go_test(
library = ":ext",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index c349b886e..2ee7cc7ac 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -70,6 +70,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
}
root := mntns.Root()
+ root.IncRef()
tearDown := func() {
root.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index 8bb104ff0..1165234f9 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -18,7 +18,7 @@ import (
"io"
"math"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,19 +34,19 @@ type blockMapFile struct {
// directBlks are the direct block numbers. The physical blocks pointed to
// by these hold file data. Contains file blocks 0 to 11.
- directBlks [numDirectBlks]uint32
+ directBlks [numDirectBlks]primitive.Uint32
// indirectBlk is the physical block which contains (blkSize/4) direct block
// numbers (as uint32 integers).
- indirectBlk uint32
+ indirectBlk primitive.Uint32
// doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
// block numbers (as uint32 integers).
- doubleIndirectBlk uint32
+ doubleIndirectBlk primitive.Uint32
// tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
// indirect block numbers (as uint32 integers).
- tripleIndirectBlk uint32
+ tripleIndirectBlk primitive.Uint32
// coverage at (i)th index indicates the amount of file data a node at
// height (i) covers. Height 0 is the direct block.
@@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
}
blkMap := file.regFile.inode.diskInode.Data()
- binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
- binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ for i := 0; i < numDirectBlks; i++ {
+ file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4])
+ }
+ file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4])
+ file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4])
+ file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4])
return file, nil
}
@@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
switch {
case offset < dirBlksEnd:
// Direct block.
- curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:])
+ curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:])
case offset < indirBlkEnd:
// Indirect block.
- curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:])
+ curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:])
case offset < doubIndirBlkEnd:
// Doubly indirect block.
- curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:])
+ curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:])
default:
// Triply indirect block.
- curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:])
+ curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:])
}
read += curR
@@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
read := 0
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
- var childPhyBlk uint32
+ var childPhyBlk primitive.Uint32
err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
- n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:])
read += n
if err != nil {
return read, err
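
The height-based dispatch above relies on each level of the block map covering geometrically more data: height 0 covers one block, and every additional level of indirection multiplies coverage by blkSize/4, the number of 4-byte block numbers a block holds. A quick sketch of that arithmetic (illustrative helper, not part of the diff):

// coverage returns how many bytes of file data a block-map node at the
// given height covers: one block at height 0, multiplied by blkSize/4
// child pointers per extra level of indirection.
func coverage(blkSize uint64, height uint) uint64 {
	cov := blkSize
	for i := uint(0); i < height; i++ {
		cov *= blkSize / 4
	}
	return cov
}

// With 4 KiB blocks: height 0 -> 4 KiB, 1 -> 4 MiB, 2 -> 4 GiB, 3 -> 4 TiB.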
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 6fa84e7aa..ed98b482e 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -20,7 +20,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
var fileData []byte
blkNums := newBlkNumGen()
- var data []byte
+ off := 0
+ data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes())
// Write the direct blocks.
for i := 0; i < numDirectBlks; i++ {
- curBlkNum := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(data[off:])
+ off += curBlkNum.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...)
}
// Write to indirect block.
- indirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
-
- // Write to indirect block.
- doublyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
-
- // Write to indirect block.
- triplyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+ indirectBlk := primitive.Uint32(blkNums.next())
+ indirectBlk.MarshalBytes(data[off:])
+ off += indirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...)
+
+ // Write to double indirect block.
+ doublyIndirectBlk := primitive.Uint32(blkNums.next())
+ doublyIndirectBlk.MarshalBytes(data[off:])
+ off += doublyIndirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...)
+
+ // Write to triple indirect block.
+ triplyIndirectBlk := primitive.Uint32(blkNums.next())
+ triplyIndirectBlk.MarshalBytes(data[off:])
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...)
args := inodeArgs{
fs: &filesystem{
@@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN
var fileData []byte
for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
- curBlkNum := blkNums.next()
- copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
- fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(disk[off : off+4])
+ fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...)
}
return fileData
}
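
The test conversions above all follow the same round trip: a primitive.Uint32 knows its own size and little-endian encoding. A minimal sketch of the API as used in this diff:

import "gvisor.dev/gvisor/pkg/marshal/primitive"

func roundTrip() uint32 {
	in := primitive.Uint32(42)
	buf := make([]byte, in.SizeBytes()) // 4 bytes.
	in.MarshalBytes(buf)                // Little-endian encoding.

	var out primitive.Uint32
	out.UnmarshalBytes(buf)
	return uint32(out) // 42
}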
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 452450d82..0ad79b381 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -16,7 +16,6 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -100,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
} else {
curDirent.diskDirent = &disklayout.DirentOld{}
}
- binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+ curDirent.diskDirent.UnmarshalBytes(buf)
if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
// Inode number and name length fields being set to 0 is used to indicate
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 9bd9c76c0..d98a05dd8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -22,10 +22,11 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/marshal",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
index ad6f4fef8..0d56ae9da 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// BlockGroup represents a Linux ext block group descriptor. An ext file system
// is split into a series of block groups. This provides an access layer to
// information needed to access and use a block group.
@@ -30,6 +34,8 @@ package disklayout
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors.
type BlockGroup interface {
+ marshal.Marshallable
+
// InodeTable returns the absolute block number of the block containing the
// inode table. This points to an array of Inode structs. Inode tables are
// statically allocated at mkfs time. The superblock records the number of
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
index 3e16c76db..a35fa22a0 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
@@ -17,6 +17,8 @@ package disklayout
// BlockGroup32Bit emulates the first half of struct ext4_group_desc in
// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and
// 32-bit ext4 filesystems. It implements BlockGroup interface.
+//
+// +marshal
type BlockGroup32Bit struct {
BlockBitmapLo uint32
InodeBitmapLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
index 9a809197a..d54d1d345 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
@@ -18,6 +18,8 @@ package disklayout
// It is the block group descriptor struct for 64-bit ext4 filesystems.
// It implements BlockGroup interface. It is an extension of the 32-bit
// version of BlockGroup.
+//
+// +marshal
type BlockGroup64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
index 0ef4294c0..e4ce484e4 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
@@ -21,6 +21,8 @@ import (
// TestBlockGroupSize tests that the block group descriptor structs are of the
// correct size.
func TestBlockGroupSize(t *testing.T) {
- assertSize(t, BlockGroup32Bit{}, 32)
- assertSize(t, BlockGroup64Bit{}, 64)
+ var bgSmall BlockGroup32Bit
+ assertSize(t, &bgSmall, 32)
+ var bgBig BlockGroup64Bit
+ assertSize(t, &bgBig, 64)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
index 417b6cf65..568c8cb4c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -15,6 +15,7 @@
package disklayout
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -51,6 +52,8 @@ var (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories.
type Dirent interface {
+ marshal.Marshallable
+
// Inode returns the absolute inode number of the underlying inode.
// Inode number 0 signifies an unused dirent.
Inode() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
index 29ae4a5c2..51f9c2946 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
@@ -29,12 +29,14 @@ import (
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentNew struct {
InodeNumber uint32
RecordLength uint16
NameLength uint8
FileTypeRaw uint8
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentNew implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
index 6fff12a6e..d4b19e086 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
@@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs"
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentOld struct {
InodeNumber uint32
RecordLength uint16
NameLength uint16
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentOld implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
index 934919f8a..3486864dc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -21,6 +21,8 @@ import (
// TestDirentSize tests that the dirent structs are of the correct
// size.
func TestDirentSize(t *testing.T) {
- assertSize(t, DirentOld{}, uintptr(DirentSize))
- assertSize(t, DirentNew{}, uintptr(DirentSize))
+ var dOld DirentOld
+ assertSize(t, &dOld, DirentSize)
+ var dNew DirentNew
+ assertSize(t, &dNew, DirentSize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
index bdf4e2132..0834e9ba8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
@@ -36,8 +36,6 @@
// escape analysis on an unknown implementation at compile time.
//
// Notes:
-// - All fields in these structs are exported because binary.Read would
-// panic otherwise.
// - All structures on disk are in little-endian order. Only jbd2 (journal)
// structures are in big-endian order.
// - All OS-dependent fields in these structures will be interpreted using
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 4110649ab..b13999bfc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// Extents were introduced in ext4 and provide huge performance gains in terms of
// data locality and reduced metadata block usage. Extents are organized in
// extent trees. The root node is contained in inode.BlocksRaw.
@@ -64,6 +68,8 @@ type ExtentNode struct {
// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
+ marshal.Marshallable
+
// FileBlock returns the first file block number covered by this entry.
FileBlock() uint32
@@ -75,6 +81,8 @@ type ExtentEntry interface {
// tree node begins with this and is followed by `NumEntries` number of:
// - Extent if `Depth` == 0
// - ExtentIdx otherwise
+//
+// +marshal
type ExtentHeader struct {
// Magic is the extent magic number, which must be 0xf30a.
Magic uint16
@@ -96,6 +104,8 @@ type ExtentHeader struct {
// internal nodes. Sorted in ascending order based on FirstFileBlock since
// Linux does a binary search on this. This points to a block containing the
// child node.
+//
+// +marshal
type ExtentIdx struct {
FirstFileBlock uint32
ChildBlockLo uint32
@@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 {
// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a
// binary search on this. This points to an array of data blocks containing the
// file data. It covers `Length` data blocks starting from `StartBlock`.
+//
+// +marshal
type Extent struct {
FirstFileBlock uint32
Length uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index 8762b90db..c96002e19 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,10 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentHeaderSize)
- assertSize(t, ExtentIdx{}, ExtentEntrySize)
- assertSize(t, Extent{}, ExtentEntrySize)
+ var h ExtentHeader
+ assertSize(t, &h, ExtentHeaderSize)
+ var i ExtentIdx
+ assertSize(t, &i, ExtentEntrySize)
+ var e Extent
+ assertSize(t, &e, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
index 88ae913f5..ef25040a9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
@@ -16,6 +16,7 @@ package disklayout
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
@@ -38,6 +39,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes.
type Inode interface {
+ marshal.Marshallable
+
// Mode returns the linux file mode which is majorly used to extract
// information like:
// - File permissions (read/write/execute by user/group/others).
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
index 8f9f574ce..a4503f5cf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
@@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time"
// are used to provide nanosecond precision. Hence, these timestamps will now
// overflow in May 2446.
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+//
+// +marshal
type InodeNew struct {
InodeOld
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
index db25b11b6..e6b28babf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
@@ -30,6 +30,8 @@ const (
//
// All fields representing time are in seconds since the epoch, which means that
// they will overflow in January 2038.
+//
+// +marshal
type InodeOld struct {
ModeRaw uint16
UIDLo uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
index dd03ee50e..90744e956 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
@@ -24,10 +24,12 @@ import (
// TestInodeSize tests that the inode structs are of the correct size.
func TestInodeSize(t *testing.T) {
- assertSize(t, InodeOld{}, OldInodeSize)
+ var iOld InodeOld
+ assertSize(t, &iOld, OldInodeSize)
// This was updated from 156 bytes to 160 bytes in Oct 2015.
- assertSize(t, InodeNew{}, 160)
+ var iNew InodeNew
+ assertSize(t, &iNew, 160)
}
// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
index 8bb327006..70948ebe9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
const (
// SbOffset is the absolute offset at which the superblock is placed.
SbOffset = 1024
@@ -38,6 +42,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block.
type SuperBlock interface {
+ marshal.Marshallable
+
// InodesCount returns the total number of inodes in this filesystem.
InodesCount() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..4dc6080fb 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
@@ -17,6 +17,8 @@ package disklayout
// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
// RevLevel = DynamicRev and 64-bit feature is disabled.
+//
+// +marshal
type SuperBlock32Bit struct {
// We embed the old superblock struct here because the 32-bit version is just
// an extension of the old version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..2c9039327 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
@@ -19,6 +19,8 @@ package disklayout
// 1024 bytes (smallest possible block size) and hence the superblock always
// fits in no more than one data block. Should only be used when the 64-bit
// feature is set.
+//
+// +marshal
type SuperBlock64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..e4709f23c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
@@ -16,6 +16,8 @@ package disklayout
// SuperBlockOld implements SuperBlock and represents the old version of the
// superblock struct. Should be used only if RevLevel = OldRev.
+//
+// +marshal
type SuperBlockOld struct {
InodesCountRaw uint32
BlocksCountLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..b734b6987 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
@@ -21,7 +21,10 @@ import (
// TestSuperBlockSize tests that the superblock structs are of the correct
// size.
func TestSuperBlockSize(t *testing.T) {
- assertSize(t, SuperBlockOld{}, 84)
- assertSize(t, SuperBlock32Bit{}, 336)
- assertSize(t, SuperBlock64Bit{}, 1024)
+ var sbOld SuperBlockOld
+ assertSize(t, &sbOld, 84)
+ var sb32 SuperBlock32Bit
+ assertSize(t, &sb32, 336)
+ var sb64 SuperBlock64Bit
+ assertSize(t, &sb64, 1024)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..a4bc08411 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
@@ -18,13 +18,13 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
)
-func assertSize(t *testing.T, v interface{}, want uintptr) {
+func assertSize(t *testing.T, v marshal.Marshallable, want int) {
t.Helper()
- if got := binary.Size(v); got != want {
+ if got := v.SizeBytes(); got != want {
t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got)
}
}
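Why the helper now takes pointers: the generated Marshallable methods use pointer receivers, and sizing goes through SizeBytes rather than reflection over the value. A minimal, self-contained sketch of the pattern — stand-in types and a trimmed interface, not gVisor's actual marshal package:

package main

import (
	"fmt"
	"reflect"
)

// sizer mirrors the one method of marshal.Marshallable the helper uses;
// the real interface carries several more methods.
type sizer interface {
	SizeBytes() int
}

// header is a hypothetical on-disk struct with two 4-byte fields.
type header struct {
	Magic uint32
	Count uint32
}

// SizeBytes reports the fixed serialized size. In gVisor this method is
// generated by the go_marshal tool from the +marshal directive.
func (*header) SizeBytes() int { return 8 }

// assertSize mimics the rewritten helper: compare the type's own
// SizeBytes against the expected on-disk layout size.
func assertSize(v sizer, want int) error {
	if got := v.SizeBytes(); got != want {
		return fmt.Errorf("struct %s should be exactly %d bytes but is %d bytes",
			reflect.TypeOf(v).Elem().Name(), want, got)
	}
	return nil
}

func main() {
	fmt.Println(assertSize(&header{}, 8)) // <nil>
}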
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index aca258d40..38fb7962b 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -38,9 +38,6 @@ const Name = "ext"
// +stateify savable
type FilesystemType struct{}
-// Compiles only if FilesystemType implements vfs.FilesystemType.
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// getDeviceFd returns an io.ReaderAt to the underlying device.
// Currently there are two ways of mounting an ext(2/3/4) fs:
// 1. Specify a mount with our internal special MountType in the OCI spec.
@@ -101,6 +98,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
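Several filesystem types in this change gain a no-op Release, implying vfs.FilesystemType grew a Release method for types that cache global state. A hedged sketch of that shape, using a trimmed-down stand-in for the interface:

package main

import (
	"context"
	"fmt"
)

// filesystemType is a stand-in for the relevant slice of
// vfs.FilesystemType; the real interface also has GetFilesystem.
type filesystemType interface {
	Name() string
	Release(ctx context.Context)
}

type extType struct{}

func (extType) Name() string { return "ext" }

// Release is a no-op: this type caches no global state that needs
// tearing down when the filesystem type is unregistered.
func (extType) Release(ctx context.Context) {}

func main() {
	var ft filesystemType = extType{}
	fmt.Println(ft.Name())
	ft.Release(context.Background())
}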
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 0989558cd..d9fd4590c 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -82,6 +82,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
}
root := mntns.Root()
+ root.IncRef()
tearDown := func() {
root.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 04917d762..778460107 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -18,7 +18,6 @@ import (
"io"
"sort"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -60,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
+ f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize])
// Root node cannot have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -79,7 +78,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
+ curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize])
f.root.Entries[i].Entry = curEntry
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index cd10d46ee..985f76ac0 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -21,7 +21,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
// writeTree writes the tree represented by `root` to the inode and disk. It
// also writes random file data on disk.
func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
- rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
+ rootData := in.diskInode.Data()
+ root.Header.MarshalBytes(rootData)
+ off := root.Header.SizeBytes()
for _, ep := range root.Entries {
- rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(rootData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(in.diskInode.Data(), rootData)
-
var fileData []byte
for _, ep := range root.Entries {
if root.Header.Height == 0 {
@@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl
// writeTreeToDisk is the recursive step for writeTree which writes the tree
// on the disk only. Also writes random file data on disk.
func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
- nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
+ nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:]
+ curNode.Node.Header.MarshalBytes(nodeData)
+ off := curNode.Node.Header.SizeBytes()
for _, ep := range curNode.Node.Entries {
- nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(nodeData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
-
var fileData []byte
for _, ep := range curNode.Node.Entries {
if curNode.Node.Header.Height == 0 {
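The rewritten test marshals the header and each entry directly into the destination buffer at a running offset, instead of building an intermediate slice and copying. A self-contained sketch of that offset-threading pattern, with a hypothetical record type rather than the real disklayout structs:

package main

import (
	"encoding/binary"
	"fmt"
)

// extentLike is a hypothetical 8-byte on-disk record.
type extentLike struct {
	FirstBlock uint32
	Length     uint32
}

func (*extentLike) SizeBytes() int { return 8 }

// MarshalBytes writes the record into dst in little-endian order, the
// byte order the generated marshalling code uses for these structs.
func (e *extentLike) MarshalBytes(dst []byte) {
	binary.LittleEndian.PutUint32(dst[0:4], e.FirstBlock)
	binary.LittleEndian.PutUint32(dst[4:8], e.Length)
}

// pack marshals each record at a running offset into buf and returns
// the number of bytes written -- the same shape as the writeTree loop.
func pack(buf []byte, recs []*extentLike) int {
	off := 0
	for _, r := range recs {
		r.MarshalBytes(buf[off:])
		off += r.SizeBytes()
	}
	return off
}

func main() {
	buf := make([]byte, 16)
	n := pack(buf, []*extentLike{{1, 2}, {3, 4}})
	fmt.Println(n, buf)
}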
diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index d8b728f8c..58ef7b9b8 100644
--- a/pkg/sentry/fsimpl/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -17,21 +17,21 @@ package ext
import (
"io"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// readFromDisk performs a binary read from disk into the given struct from
// the absolute offset provided.
-func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
- n := binary.Size(v)
+func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error {
+ n := v.SizeBytes()
buf := make([]byte, n)
if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
return syserror.EIO
}
- binary.Unmarshal(buf, binary.LittleEndian, v)
+ v.UnmarshalBytes(buf)
return nil
}
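The new readFromDisk sizes its buffer from the struct itself and lets the struct decode its own bytes, removing the reflection-based pkg/binary path. A runnable sketch of the same pattern with stand-in types (not the real marshal package):

package main

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)

// unmarshaler mirrors the subset of marshal.Marshallable that
// readFromDisk relies on.
type unmarshaler interface {
	SizeBytes() int
	UnmarshalBytes(src []byte)
}

// sbHeader is a hypothetical 4-byte on-disk header.
type sbHeader struct{ InodesCount uint32 }

func (*sbHeader) SizeBytes() int { return 4 }
func (h *sbHeader) UnmarshalBytes(src []byte) {
	h.InodesCount = binary.LittleEndian.Uint32(src)
}

// readFromDisk mirrors the updated helper: size the buffer from the
// struct, read from the absolute offset, then decode.
func readFromDisk(dev io.ReaderAt, off int64, v unmarshaler) error {
	buf := make([]byte, v.SizeBytes())
	if n, _ := dev.ReadAt(buf, off); n < len(buf) {
		return errors.New("EIO: short read")
	}
	v.UnmarshalBytes(buf)
	return nil
}

func main() {
	disk := bytes.NewReader([]byte{0x10, 0, 0, 0})
	var h sbHeader
	if err := readFromDisk(disk, 0, &h); err == nil {
		fmt.Println(h.InodesCount) // 16
	}
}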
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 65786e42a..e39df21c6 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -98,6 +98,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
@@ -249,14 +252,12 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type inode struct {
inodeRefs
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.OrderedChildren
- dentry kernfs.Dentry
-
// the owning filesystem. fs is immutable.
fs *filesystem
@@ -284,26 +285,24 @@ type inode struct {
}
func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
- i := &inode{fs: fs}
+ i := &inode{fs: fs, nodeID: 1}
i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
- i.dentry.Init(i)
- i.nodeID = 1
- return &i.dentry
+ var d kernfs.Dentry
+ d.Init(&fs.Filesystem, i)
+ return &d
}
-func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry {
+func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
i := &inode{fs: fs, nodeID: nodeID}
creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
- i.dentry.Init(i)
-
- return &i.dentry
+ return i
}
// Open implements kernfs.Inode.Open.
@@ -410,23 +409,27 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
}
// Lookup implements kernfs.Inode.Lookup.
-func (i *inode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
in := linux.FUSELookupIn{Name: name}
return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in)
}
+// Keep implements kernfs.Inode.Keep.
+func (i *inode) Keep() bool {
+ // Return true so that kernfs keeps the new dentry pointing to this
+ // inode in the dentry tree. This is needed because inodes created via
+ // Lookup are not temporary. They might refer to existing files on the
+ // server that can be Unlink'd/Rmdir'd.
+ return true
+}
+
// IterDirents implements kernfs.Inode.IterDirents.
func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
-// Valid implements kernfs.Inode.Valid.
-func (*inode) Valid(ctx context.Context) bool {
- return true
-}
-
// NewFile implements kernfs.Inode.NewFile.
-func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID)
@@ -444,7 +447,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions)
}
// NewNode implements kernfs.Inode.NewNode.
-func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (kernfs.Inode, error) {
in := linux.FUSEMknodIn{
MknodMeta: linux.FUSEMknodMeta{
Mode: uint32(opts.Mode),
@@ -457,7 +460,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions)
}
// NewSymlink implements kernfs.Inode.NewSymlink.
-func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.Dentry, error) {
+func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) {
in := linux.FUSESymLinkIn{
Name: name,
Target: target,
@@ -466,7 +469,7 @@ func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.De
}
// Unlink implements kernfs.Inode.Unlink.
-func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) error {
+func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) error {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
@@ -482,14 +485,11 @@ func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) e
return err
}
// only return error, discard res.
- if err := res.Error(); err != nil {
- return err
- }
- return i.dentry.RemoveChildLocked(name, child)
+ return res.Error()
}
// NewDir implements kernfs.Inode.NewDir.
-func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
+func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
in := linux.FUSEMkdirIn{
MkdirMeta: linux.FUSEMkdirMeta{
Mode: uint32(opts.Mode),
@@ -501,7 +501,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions)
}
// RmDir implements kernfs.Inode.RmDir.
-func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) error {
+func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) error {
fusefs := i.fs
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
@@ -515,16 +515,12 @@ func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) er
if err != nil {
return err
}
- if err := res.Error(); err != nil {
- return err
- }
-
- return i.dentry.RemoveChildLocked(name, child)
+ return res.Error()
}
// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response.
// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP.
-func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*kernfs.Dentry, error) {
+func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (kernfs.Inode, error) {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
@@ -734,8 +730,8 @@ func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptio
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *inode) DecRef(context.Context) {
- i.inodeRefs.DecRef(i.Destroy)
+func (i *inode) DecRef(ctx context.Context) {
+ i.inodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// StatFS implements kernfs.Inode.StatFS.
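The Keep hook introduced here lets each inode tell kernfs whether dentries pointing to it should stay cached after the current operation. A sketch of the two policies in play, with simplified stand-in types:

package main

import "fmt"

// keeper is the slice of the kernfs.Inode interface relevant here.
type keeper interface {
	Keep() bool
}

// serverBacked models a fusefs-style inode: entries found via Lookup
// correspond to real files on the server, so kernfs must keep their
// dentries in the tree where later Unlink/Rmdir can find them.
type serverBacked struct{}

func (serverBacked) Keep() bool { return true }

// temporary models what embedding kernfs.InodeTemporary provides: the
// dentry is pruned at the end of the filesystem operation unless some
// other reference keeps it alive.
type temporary struct{}

func (temporary) Keep() bool { return false }

// afterLookup mirrors revalidateChildLocked's decision: drop the ref
// owned by kernfs (queuing the dentry for pruning) unless Keep is set.
func afterLookup(i keeper) string {
	if i.Keep() {
		return "dentry retained in tree"
	}
	return "dentry queued for pruning"
}

func main() {
	fmt.Println(afterLookup(serverBacked{}))
	fmt.Println(afterLookup(temporary{}))
}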
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 16787116f..ad0afc41b 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -52,6 +52,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/p9",
+ "//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 8608471f8..f1dad1b08 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -272,6 +272,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index eeaf6e444..f8b19bae7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -395,7 +395,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
End: gapEnd,
}
optMR := gap.Range()
- err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, h.readToBlocksAt)
+ err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), rw.d.size, mf, usage.PageCache, h.readToBlocksAt)
mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End})
seg, gap = rw.d.cache.Find(rw.off)
if !seg.Ok() {
@@ -403,10 +403,10 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
rw.d.handleMu.RUnlock()
return done, err
}
- // err might have occurred in part of gap.Range() outside
- // gapMR. Forget about it for now; if the error matters and
- // persists, we'll run into it again in a later iteration of
- // this loop.
+ // err might have occurred in part of gap.Range() outside gapMR
+ // (in particular, gap.End() might be beyond EOF). Forget about
+ // it for now; if the error matters and persists, we'll run
+ // into it again in a later iteration of this loop.
} else {
// Read directly from the file.
gapDsts := dsts.TakeFirst64(gapMR.Length())
@@ -780,7 +780,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
mf := d.fs.mfp.MemoryFile()
h := d.readHandleLocked()
- cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, h.readToBlocksAt)
+ cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), d.size, mf, usage.PageCache, h.readToBlocksAt)
var ts []memmap.Translation
var translatedEnd uint64
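cache.Fill now receives the file size so the fill range can be clipped at EOF instead of optimistically reading past it. A tiny sketch of the clipping step — clampFill is our illustrative name, not the fsutil API:

package main

import "fmt"

// clampFill clips an optimistic fill range [start, end) to the current
// file size, so the cache never tries to populate pages past EOF.
func clampFill(start, end, size uint64) (uint64, uint64) {
	if end > size {
		end = size
	}
	if start > end {
		start = end
	}
	return start, end
}

func main() {
	fmt.Println(clampFill(4096, 16384, 10000)) // 4096 10000
}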
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index ffe4ddb32..698e913fe 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -118,7 +118,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
if err != nil {
return nil, err
}
- d.Init(i)
+ d.Init(&fs.Filesystem, i)
// i.open will take a reference on d.
defer d.DecRef(ctx)
@@ -151,6 +151,9 @@ func (filesystemType) Name() string {
return "none"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// NewFilesystem sets up and returns a new hostfs filesystem.
//
// Note that there should only ever be one instance of host.filesystem,
@@ -195,6 +198,7 @@ type inode struct {
kernfs.InodeNoStatFS
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // Inconsequential: this inode can't be looked up and is always valid.
locks vfs.FileLocks
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 131145b85..8a447e29f 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -348,10 +348,10 @@ func (e *SCMConnectedEndpoint) Init() error {
func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
e.DecRef(func() {
e.mu.Lock()
+ fdnotifier.RemoveFD(int32(e.fd))
if err := syscall.Close(e.fd); err != nil {
log.Warningf("Failed to close host fd %d: %v", err)
}
- fdnotifier.RemoveFD(int32(e.fd))
e.destroyLocked()
e.mu.Unlock()
})
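The reordering above matters: once the host fd is closed, its number can be reused by another file, so the notifier must stop watching it first. A condensed, self-contained restatement of the fixed teardown order (removeFD stands in for fdnotifier.RemoveFD):

package main

import (
	"fmt"
	"syscall"
)

// removeFD stands in for fdnotifier.RemoveFD: deregister fd from the
// notifier before anything can recycle the fd number.
func removeFD(fd int32) { fmt.Println("stopped watching fd", fd) }

// releaseHostFD shows the fixed ordering: deregister first, close
// second. Closing first opens a window where the kernel reuses the fd
// number while the notifier still holds a watch on the old entry.
func releaseHostFD(fd int) {
	removeFD(int32(fd))
	if err := syscall.Close(fd); err != nil {
		fmt.Printf("failed to close host fd %d: %v\n", fd, err)
	}
}

func main() {
	fds, _ := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
	releaseHostFD(fds[0])
	releaseHostFD(fds[1])
}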
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 5e91e0536..858cc24ce 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -70,6 +70,17 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "synthetic_directory_refs",
+ out = "synthetic_directory_refs.go",
+ package = "kernfs",
+ prefix = "syntheticDirectory",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "syntheticDirectory",
+ },
+)
+
go_library(
name = "kernfs",
srcs = [
@@ -84,6 +95,7 @@ go_library(
"static_directory_refs.go",
"symlink.go",
"synthetic_directory.go",
+ "synthetic_directory_refs.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 0a4cd4057..abf1905d6 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -201,12 +201,12 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// these.
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
- stat, err := it.Dentry.inode.Stat(ctx, fd.filesystem(), opts)
+ stat, err := it.inode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
dirent := vfs.Dirent{
- Name: it.Name,
+ Name: it.name,
Type: linux.FileMode(stat.Mode).DirentType(),
Ino: stat.Ino,
NextOff: fd.off + 1,
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 5cc1c4281..6426a55f6 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -89,7 +89,7 @@ afterSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef(ctx)
+ fs.deferDecRefVD(ctx, targetVD)
if err != nil {
return nil, err
}
@@ -120,22 +120,33 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// Cached dentry exists, revalidate.
if !child.inode.Valid(ctx) {
delete(parent.children, name)
- vfsObj.InvalidateDentry(ctx, &child.vfsd)
- fs.deferDecRef(child) // Reference from Lookup.
+ if child.inode.Keep() {
+ // Drop the ref owned by kernfs.
+ fs.deferDecRef(child)
+ }
+ vfsObj.InvalidateDentry(ctx, child.VFSDentry())
child = nil
}
}
if child == nil {
// Dentry isn't cached; it either doesn't exist or failed revalidation.
// Attempt to resolve it via Lookup.
- c, err := parent.inode.Lookup(ctx, name)
+ childInode, err := parent.inode.Lookup(ctx, name)
if err != nil {
return nil, err
}
- // Reference on c (provided by Lookup) will be dropped when the dentry
- // fails validation.
- parent.InsertChildLocked(name, c)
- child = c
+ var newChild Dentry
+ newChild.Init(fs, childInode) // childInode's ref is transferred to newChild.
+ parent.insertChildLocked(name, &newChild)
+ child = &newChild
+
+ // Drop the ref on newChild. This will cause the dentry to get pruned
+ // from the dentry tree by the end of the current filesystem operation
+ // (before returning to the VFS layer) if no other ref is taken on
+ // this dentry.
+ if !childInode.Keep() {
+ fs.deferDecRef(&newChild)
+ }
}
return child, nil
}
@@ -191,7 +202,7 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
}
// checkCreateLocked checks that a file named rp.Component() may be created in
-// directory parentVFSD, then returns rp.Component().
+// directory parent, then returns rp.Component().
//
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
@@ -298,9 +309,9 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -324,11 +335,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EPERM
}
- child, err := parent.inode.NewLink(ctx, pc, d.inode)
+ childI, err := parent.inode.NewLink(ctx, pc, d.inode)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
@@ -338,9 +351,9 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -355,14 +368,16 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- child, err := parent.inode.NewDir(ctx, pc, opts)
+ childI, err := parent.inode.NewDir(ctx, pc, opts)
if err != nil {
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
- child = newSyntheticDirectory(rp.Credentials(), opts.Mode)
+ childI = newSyntheticDirectory(rp.Credentials(), opts.Mode)
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
@@ -372,9 +387,9 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -389,11 +404,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- newD, err := parent.inode.NewNode(ctx, pc, opts)
+ newI, err := parent.inode.NewNode(ctx, pc, opts)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, newD)
+ var newD Dentry
+ newD.Init(fs, newI)
+ parent.insertChildLocked(pc, &newD)
return nil
}
@@ -409,22 +426,23 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
return nil, err
}
if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
return nil, err
}
- d.inode.IncRef()
- defer d.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef d to prevent
+ // its destruction while fs.mu is unlocked.
+ d.IncRef()
fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
- return d.inode.Open(ctx, rp, d, opts)
+ fd, err := d.inode.Open(ctx, rp, d, opts)
+ d.DecRef(ctx)
+ return fd, err
}
// May create new file.
@@ -438,6 +456,10 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
unlocked = true
}
}
+ // Process all to-be-decref'd dentries at the end at once.
+ // Since we defer unlock() AFTER this, fs.mu is guaranteed to be unlocked
+ // when this is executed.
+ defer fs.processDeferredDecRefs(ctx)
defer unlock()
if rp.Done() {
if rp.MustBeDir() {
@@ -449,14 +471,16 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- d.inode.IncRef()
- defer d.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef d to prevent
+ // its destruction while fs.mu is unlocked.
+ d.IncRef()
unlock()
- return d.inode.Open(ctx, rp, d, opts)
+ fd, err := d.inode.Open(ctx, rp, d, opts)
+ d.DecRef(ctx)
+ return fd, err
}
afterTrailingSymlink:
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return nil, err
}
@@ -487,18 +511,23 @@ afterTrailingSymlink:
}
defer rp.Mount().EndWrite()
// Create and open the child.
- child, err := parent.inode.NewFile(ctx, pc, opts)
+ childI, err := parent.inode.NewFile(ctx, pc, opts)
if err != nil {
return nil, err
}
+ var child Dentry
+ child.Init(fs, childI)
// FIXME(gvisor.dev/issue/1193): Race between checking existence with
- // fs.stepExistingLocked and parent.InsertChild. If possible, we should hold
+ // fs.stepExistingLocked and parent.insertChild. If possible, we should hold
// dirMu from one to the other.
- parent.InsertChild(pc, child)
- child.inode.IncRef()
- defer child.inode.DecRef(ctx)
+ parent.insertChild(pc, &child)
+ // Open may block so we need to unlock fs.mu. IncRef child to prevent
+ // its destruction while fs.mu is unlocked.
+ child.IncRef()
unlock()
- return child.inode.Open(ctx, rp, child, opts)
+ fd, err := child.inode.Open(ctx, rp, &child, opts)
+ child.DecRef(ctx)
+ return fd, err
}
if err != nil {
return nil, err
@@ -514,7 +543,7 @@ afterTrailingSymlink:
}
if targetVD.Ok() {
err := rp.HandleJump(targetVD)
- targetVD.DecRef(ctx)
+ fs.deferDecRefVD(ctx, targetVD)
if err != nil {
return nil, err
}
@@ -530,18 +559,21 @@ afterTrailingSymlink:
if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- child.inode.IncRef()
- defer child.inode.DecRef(ctx)
+ // Open may block so we need to unlock fs.mu. IncRef child to prevent
+ // its destruction while fs.mu is unlocked.
+ child.IncRef()
unlock()
- return child.inode.Open(ctx, rp, child, opts)
+ fd, err := child.inode.Open(ctx, rp, child, opts)
+ child.DecRef(ctx)
+ return fd, err
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -560,7 +592,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
fs.mu.Lock()
- defer fs.processDeferredDecRefsLocked(ctx)
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
// Resolve the destination directory first to verify that it's on this
@@ -632,24 +664,27 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
return err
}
- replaced, err := srcDir.inode.Rename(ctx, src.name, pc, src, dstDir)
+ err = srcDir.inode.Rename(ctx, src.name, pc, src.inode, dstDir.inode)
if err != nil {
virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
return err
}
delete(srcDir.children, src.name)
if srcDir != dstDir {
- fs.deferDecRef(srcDir)
- dstDir.IncRef()
+ fs.deferDecRef(srcDir) // child (src) drops ref on old parent.
+ dstDir.IncRef() // child (src) takes a ref on the new parent.
}
src.parent = dstDir
src.name = pc
if dstDir.children == nil {
dstDir.children = make(map[string]*Dentry)
}
+ replaced := dstDir.children[pc]
dstDir.children[pc] = src
var replaceVFSD *vfs.Dentry
if replaced != nil {
+ // Defer the DecRef so that it runs only after fs.mu and dstDir.mu are unlocked.
+ fs.deferDecRef(replaced)
replaceVFSD = replaced.VFSDentry()
}
virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD)
@@ -659,10 +694,10 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -691,10 +726,13 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
- if err := parentDentry.inode.RmDir(ctx, d.name, d); err != nil {
+ if err := parentDentry.inode.RmDir(ctx, d.name, d.inode); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
+ delete(parentDentry.children, d.name)
+ // Defer the DecRef so that it runs only after fs.mu and parentDentry.dirMu are unlocked.
+ fs.deferDecRef(d)
virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -702,9 +740,9 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -717,9 +755,9 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
// StatAt implements vfs.FilesystemImpl.StatAt.
func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statx{}, err
}
@@ -729,9 +767,9 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statfs{}, err
}
@@ -744,9 +782,9 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return syserror.EEXIST
}
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
parent, err := fs.walkParentDirLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -761,21 +799,23 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return err
}
defer rp.Mount().EndWrite()
- child, err := parent.inode.NewSymlink(ctx, pc, target)
+ childI, err := parent.inode.NewSymlink(ctx, pc, target)
if err != nil {
return err
}
- parent.InsertChildLocked(pc, child)
+ var child Dentry
+ child.Init(fs, childI)
+ parent.insertChildLocked(pc, &child)
return nil
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
+ defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
@@ -799,10 +839,13 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
- if err := parentDentry.inode.Unlink(ctx, d.name, d); err != nil {
+ if err := parentDentry.inode.Unlink(ctx, d.name, d.inode); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
+ delete(parentDentry.children, d.name)
+ // Defer the DecRef so that it runs only after fs.mu and parentDentry.dirMu are unlocked.
+ fs.deferDecRef(d)
virtfs.CommitDeleteDentry(ctx, vfsd)
return nil
}
@@ -810,9 +853,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
d, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -825,9 +868,9 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
@@ -838,9 +881,9 @@ func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
@@ -851,9 +894,9 @@ func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -864,9 +907,9 @@ func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
_, err := fs.walkExistingLocked(ctx, rp)
- fs.mu.RUnlock()
- fs.processDeferredDecRefs(ctx)
if err != nil {
return err
}
@@ -880,3 +923,16 @@ func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.mu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*Dentry), b)
}
+
+func (fs *Filesystem) deferDecRefVD(ctx context.Context, vd vfs.VirtualDentry) {
+ if d, ok := vd.Dentry().Impl().(*Dentry); ok && d.fs == fs {
+ // The following is equivalent to vd.DecRef(ctx). This is needed
+ // because if d belongs to this filesystem, we cannot DecRef it right
+ // away, as we may be holding fs.mu and d.DecRef may acquire fs.mu. So
+ // we defer the DecRef until the locks are dropped.
+ vd.Mount().DecRef(ctx)
+ fs.deferDecRef(d)
+ } else {
+ vd.DecRef(ctx)
+ }
+}
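Throughout this file the locking pattern relies on Go's LIFO defer order: processDeferredDecRefs is deferred before the unlock, so it actually runs after the lock is released. A sketch of why that ordering holds — statLike and the stand-in types are ours, not the kernfs API:

package main

import (
	"fmt"
	"sync"
)

type fsLike struct {
	mu      sync.RWMutex
	dropped []string
}

func (fs *fsLike) deferDecRef(name string) { fs.dropped = append(fs.dropped, name) }

func (fs *fsLike) processDeferredDecRefs() {
	for _, d := range fs.dropped {
		fmt.Println("DecRef", d, "(lock already released)")
	}
	fs.dropped = fs.dropped[:0]
}

// statLike mirrors the pattern used by StatAt and friends. Deferred
// calls run last-in-first-out: RUnlock (deferred second) fires first,
// then processDeferredDecRefs -- so no DecRef ever runs under fs.mu.
func (fs *fsLike) statLike() {
	fs.mu.RLock()
	defer fs.processDeferredDecRefs()
	defer fs.mu.RUnlock()

	// A walk may invalidate cached dentries and queue them here.
	fs.deferDecRef("stale-child")
}

func main() {
	var fs fsLike
	fs.statLike()
}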
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 49210e748..122b10591 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -34,6 +34,7 @@ import (
//
// +stateify savable
type InodeNoopRefCount struct {
+ InodeTemporary
}
// IncRef implements Inode.IncRef.
@@ -57,27 +58,27 @@ func (InodeNoopRefCount) TryIncRef() bool {
type InodeDirectoryNoNewChildren struct{}
// NewFile implements Inode.NewFile.
-func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewLink implements Inode.NewLink.
-func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (Inode, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (Inode, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (Inode, error) {
return nil, syserror.EPERM
}
@@ -88,6 +89,7 @@ func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOpt
//
// +stateify savable
type InodeNotDirectory struct {
+ InodeAlwaysValid
}
// HasChildren implements Inode.HasChildren.
@@ -96,47 +98,47 @@ func (InodeNotDirectory) HasChildren() bool {
}
// NewFile implements Inode.NewFile.
-func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (Inode, error) {
panic("NewFile called on non-directory inode")
}
// NewDir implements Inode.NewDir.
-func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (Inode, error) {
panic("NewDir called on non-directory inode")
}
// NewLink implements Inode.NewLink.
-func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*Dentry, error) {
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (Inode, error) {
panic("NewLink called on non-directory inode")
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*Dentry, error) {
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (Inode, error) {
panic("NewSymlink called on non-directory inode")
}
// NewNode implements Inode.NewNode.
-func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (Inode, error) {
panic("NewNode called on non-directory inode")
}
// Unlink implements Inode.Unlink.
-func (InodeNotDirectory) Unlink(context.Context, string, *Dentry) error {
+func (InodeNotDirectory) Unlink(context.Context, string, Inode) error {
panic("Unlink called on non-directory inode")
}
// RmDir implements Inode.RmDir.
-func (InodeNotDirectory) RmDir(context.Context, string, *Dentry) error {
+func (InodeNotDirectory) RmDir(context.Context, string, Inode) error {
panic("RmDir called on non-directory inode")
}
// Rename implements Inode.Rename.
-func (InodeNotDirectory) Rename(context.Context, string, string, *Dentry, *Dentry) (*Dentry, error) {
+func (InodeNotDirectory) Rename(context.Context, string, string, Inode, Inode) error {
panic("Rename called on non-directory inode")
}
// Lookup implements Inode.Lookup.
-func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*Dentry, error) {
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) {
panic("Lookup called on non-directory inode")
}
@@ -145,35 +147,6 @@ func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDiren
panic("IterDirents called on non-directory inode")
}
-// Valid implements Inode.Valid.
-func (InodeNotDirectory) Valid(context.Context) bool {
- return true
-}
-
-// InodeNoDynamicLookup partially implements the Inode interface, specifically
-// the inodeDynamicLookup sub interface. Directory inodes that do not support
-// dymanic entries (i.e. entries that are not "hashed" into the
-// vfs.Dentry.children) can embed this to provide no-op implementations for
-// functions related to dynamic entries.
-//
-// +stateify savable
-type InodeNoDynamicLookup struct{}
-
-// Lookup implements Inode.Lookup.
-func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*Dentry, error) {
- return nil, syserror.ENOENT
-}
-
-// IterDirents implements Inode.IterDirents.
-func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- return offset, nil
-}
-
-// Valid implements Inode.Valid.
-func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
- return true
-}
-
// InodeNotSymlink partially implements the Inode interface, specifically the
// inodeSymlink sub interface. All inodes that are not symlinks may embed this
// to return the appropriate errors from symlink-related functions.
@@ -273,7 +246,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
// This function can be used by other kernfs-based filesystem implementations to
-// sets the unexported attributes into kernfs.InodeAttrs.
+// set the unexported attributes into InodeAttrs.
func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
@@ -344,8 +317,9 @@ func (a *InodeAttrs) DecLinks() {
// +stateify savable
type slot struct {
- Name string
- Dentry *Dentry
+ name string
+ inode Inode
+ static bool
slotEntry
}
@@ -361,10 +335,18 @@ type OrderedChildrenOptions struct {
}
// OrderedChildren partially implements the Inode interface. OrderedChildren can
-// be embedded in directory inodes to keep track of the children in the
+// be embedded in directory inodes to keep track of children in the
// directory, and can then be used to implement a generic directory FD -- see
-// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
-// directories.
+// GenericDirectoryFD.
+//
+// OrderedChildren can represent a node in an Inode tree. The children inodes
+// might themselves be directories using OrderedChildren, extending the tree.
+// The parent inode (the OrderedChildren user) holds a ref on all its static
+// children. This lets the static inodes outlive their associated dentry:
+// while the dentry might have to be regenerated via a Lookup() call, the
+// same static inode can keep being reused. These static children inodes are
+// DecRef'd when this directory inode is destroyed. This makes
+// OrderedChildren suitable for static directory entries as well.
//
// Must be initialized with Init before first use.
//
@@ -388,33 +370,63 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
// Destroy clears the children stored in o. It should be called by structs
// embedding OrderedChildren upon destruction, i.e. when their reference count
// reaches zero.
-func (o *OrderedChildren) Destroy() {
+func (o *OrderedChildren) Destroy(ctx context.Context) {
o.mu.Lock()
defer o.mu.Unlock()
+ // Drop the ref that o owns on the static inodes it holds.
+ for _, s := range o.set {
+ if s.static {
+ s.inode.DecRef(ctx)
+ }
+ }
o.order.Reset()
o.set = nil
}
-// Populate inserts children into this OrderedChildren, and d's dentry
-// cache. Populate returns the number of directories inserted, which the caller
+// Populate inserts static children into this OrderedChildren.
+// Populate returns the number of directories inserted, which the caller
// may use to update the link count for the parent directory.
//
-// Precondition: d must represent a directory inode. children must not contain
-// any conflicting entries already in o.
-func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
+// Preconditions:
+// * d must represent a directory inode.
+// * children must not contain any conflicting entries already in o.
+// * Caller must hold a reference on all inodes passed.
+//
+// Postcondition: Caller's references on inodes are transferred to o.
+func (o *OrderedChildren) Populate(children map[string]Inode) uint32 {
var links uint32
for name, child := range children {
- if child.isDir() {
+ if child.Mode().IsDir() {
links++
}
- if err := o.Insert(name, child); err != nil {
- panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
+ if err := o.insert(name, child, true); err != nil {
+ panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v)", name, child))
}
- d.InsertChild(name, child)
}
return links
}
+// Lookup implements Inode.Lookup.
+func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error) {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+
+ s, ok := o.set[name]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+
+ s.inode.IncRef() // This ref is passed to the dentry upon creation via Init.
+ return s.inode, nil
+}
+
+// IterDirents implements Inode.IterDirents.
+func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ // All entries from OrderedChildren have already been handled in
+ // GenericDirectoryFD.IterDirents.
+ return offset, nil
+}
+
// HasChildren implements Inode.HasChildren.
func (o *OrderedChildren) HasChildren() bool {
o.mu.RLock()
@@ -422,17 +434,27 @@ func (o *OrderedChildren) HasChildren() bool {
return len(o.set) > 0
}
-// Insert inserts child into o. This ignores the writability of o, as this is
-// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
-func (o *OrderedChildren) Insert(name string, child *Dentry) error {
+// Insert inserts a dynamic child into o. This ignores the writability of o, as
+// this is not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
+func (o *OrderedChildren) Insert(name string, child Inode) error {
+ return o.insert(name, child, false)
+}
+
+// insert inserts child into o.
+//
+// Precondition: Caller must be holding a ref on child if static is true.
+//
+// Postcondition: Caller's ref on child is transferred to o if static is true.
+func (o *OrderedChildren) insert(name string, child Inode, static bool) error {
o.mu.Lock()
defer o.mu.Unlock()
if _, ok := o.set[name]; ok {
return syserror.EEXIST
}
s := &slot{
- Name: name,
- Dentry: child,
+ name: name,
+ inode: child,
+ static: static,
}
o.order.PushBack(s)
o.set[name] = s
@@ -442,44 +464,49 @@ func (o *OrderedChildren) Insert(name string, child *Dentry) error {
// Precondition: caller must hold o.mu for writing.
func (o *OrderedChildren) removeLocked(name string) {
if s, ok := o.set[name]; ok {
+ if s.static {
+ panic(fmt.Sprintf("removeLocked called on a static inode: %v", s.inode))
+ }
delete(o.set, name)
o.order.Remove(s)
}
}
// Precondition: caller must hold o.mu for writing.
-func (o *OrderedChildren) replaceChildLocked(name string, new *Dentry) *Dentry {
+func (o *OrderedChildren) replaceChildLocked(ctx context.Context, name string, newI Inode) {
if s, ok := o.set[name]; ok {
+ if s.static {
+ panic(fmt.Sprintf("replacing a static inode: %v", s.inode))
+ }
+
// Existing slot with given name, simply replace the inode.
- var old *Dentry
- old, s.Dentry = s.Dentry, new
- return old
+ s.inode = newI
}
// No existing slot with given name, create and hash new slot.
s := &slot{
- Name: name,
- Dentry: new,
+ name: name,
+ inode: newI,
+ static: false,
}
o.order.PushBack(s)
o.set[name] = s
- return nil
}
// Precondition: caller must hold o.mu for reading or writing.
-func (o *OrderedChildren) checkExistingLocked(name string, child *Dentry) error {
+func (o *OrderedChildren) checkExistingLocked(name string, child Inode) error {
s, ok := o.set[name]
if !ok {
return syserror.ENOENT
}
- if s.Dentry != child {
- panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child))
+ if s.inode != child {
+ panic(fmt.Sprintf("Inode doesn't match what kernfs thinks! OrderedChild: %+v, kernfs: %+v", s.inode, child))
}
return nil
}
// Unlink implements Inode.Unlink.
-func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry) error {
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child Inode) error {
if !o.writable {
return syserror.EPERM
}
@@ -494,8 +521,8 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry
return nil
}
-// Rmdir implements Inode.Rmdir.
-func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *Dentry) error {
+// RmDir implements Inode.RmDir.
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) error {
// We're not responsible for checking that child is a directory, that it's
// empty, or updating any link counts; so this is the same as unlink.
return o.Unlink(ctx, name, child)
@@ -517,13 +544,13 @@ func (renameAcrossDifferentImplementationsError) Error() string {
// that will support Rename.
//
// Postcondition: reference on any replaced dentry transferred to caller.
-func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (*Dentry, error) {
- dst, ok := dstDir.inode.(interface{}).(*OrderedChildren)
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error {
+ dst, ok := dstDir.(interface{}).(*OrderedChildren)
if !ok {
- return nil, renameAcrossDifferentImplementationsError{}
+ return renameAcrossDifferentImplementationsError{}
}
if !o.writable || !dst.writable {
- return nil, syserror.EPERM
+ return syserror.EPERM
}
// Note: There's a potential deadlock below if concurrent calls to Rename
// refer to the same src and dst directories in reverse. We avoid any
@@ -536,12 +563,12 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c
defer dst.mu.Unlock()
}
if err := o.checkExistingLocked(oldname, child); err != nil {
- return nil, err
+ return err
}
// TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
- replaced := dst.replaceChildLocked(newname, child)
- return replaced, nil
+ dst.replaceChildLocked(ctx, newname, child)
+ return nil
}
// nthLocked returns an iterator to the nth child tracked by this object. The
@@ -576,11 +603,12 @@ func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry,
//
// +stateify savable
type StaticDirectory struct {
+ InodeAlwaysValid
InodeAttrs
InodeDirectoryNoNewChildren
- InodeNoDynamicLookup
InodeNoStatFS
InodeNotSymlink
+ InodeTemporary
OrderedChildren
StaticDirectoryRefs
@@ -591,19 +619,16 @@ type StaticDirectory struct {
var _ Inode = (*StaticDirectory)(nil)
// NewStaticDir creates a new static directory and returns its dentry.
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry {
+func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
inode.EnableLeakCheck()
- dentry := &Dentry{}
- dentry.Init(inode)
-
inode.OrderedChildren.Init(OrderedChildrenOptions{})
- links := inode.OrderedChildren.Populate(dentry, children)
+ links := inode.OrderedChildren.Populate(children)
inode.IncLinks(links)
- return dentry
+ return inode
}
// Init initializes StaticDirectory.
@@ -615,7 +640,7 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3
s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
-// Open implements kernfs.Inode.Open.
+// Open implements Inode.Open.
func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd, err := NewGenericDirectoryFD(rp.Mount(), d, &s.OrderedChildren, &s.locks, &opts, s.fdOpts)
if err != nil {
@@ -624,26 +649,36 @@ func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *De
return fd.VFSFileDescription(), nil
}
-// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
-// DecRef implements kernfs.Inode.DecRef.
-func (s *StaticDirectory) DecRef(context.Context) {
- s.StaticDirectoryRefs.DecRef(s.Destroy)
+// DecRef implements Inode.DecRef.
+func (s *StaticDirectory) DecRef(ctx context.Context) {
+ s.StaticDirectoryRefs.DecRef(func() { s.Destroy(ctx) })
}
-// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+// InodeAlwaysValid partially implements Inode.
//
// +stateify savable
-type AlwaysValid struct{}
+type InodeAlwaysValid struct{}
-// Valid implements kernfs.inodeDynamicLookup.Valid.
-func (*AlwaysValid) Valid(context.Context) bool {
+// Valid implements Inode.Valid.
+func (*InodeAlwaysValid) Valid(context.Context) bool {
return true
}
+// InodeTemporary partially implements Inode.
+//
+// +stateify savable
+type InodeTemporary struct{}
+
+// Keep implements Inode.Keep.
+func (*InodeTemporary) Keep() bool {
+ return false
+}
+
// InodeNoStatFS partially implements the Inode interface, where the client
// filesystem doesn't support statfs(2).
//
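The new reference model has the directory inode own a ref on each static child, dropped only in Destroy, while Lookup hands out a fresh ref for the dentry. A self-contained sketch of that ownership transfer, using toy refcounts rather than the kernfs types:

package main

import (
	"fmt"
	"sync/atomic"
)

type childInode struct {
	name string
	refs int64
}

func (c *childInode) IncRef() { atomic.AddInt64(&c.refs, 1) }
func (c *childInode) DecRef() { atomic.AddInt64(&c.refs, -1) }

type staticDir struct {
	children map[string]*childInode
}

// populate takes over the caller's ref on every child, like
// OrderedChildren.Populate's postcondition.
func (d *staticDir) populate(children map[string]*childInode) {
	d.children = children
}

// lookup gives the new dentry its own ref; the directory keeps the
// static ref it took in populate, so the inode survives the dentry.
func (d *staticDir) lookup(name string) *childInode {
	c, ok := d.children[name]
	if !ok {
		return nil // ENOENT in the real code
	}
	c.IncRef()
	return c
}

// destroy drops the directory's static refs, mirroring
// OrderedChildren.Destroy.
func (d *staticDir) destroy() {
	for _, c := range d.children {
		c.DecRef()
	}
}

func main() {
	child := &childInode{name: "stat", refs: 1} // caller's ref
	dir := &staticDir{}
	dir.populate(map[string]*childInode{"stat": child}) // ref moves to dir
	if c := dir.lookup("stat"); c != nil {
		fmt.Println(c.name, "refs:", atomic.LoadInt64(&c.refs)) // 2
		c.DecRef()                                              // dentry dropped
	}
	dir.destroy()
	fmt.Println("refs after destroy:", atomic.LoadInt64(&child.refs)) // 0
}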
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 6d3d79333..606081e68 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -29,12 +29,16 @@
//
// Reference Model:
//
-// Kernfs dentries represents named pointers to inodes. Dentries and inodes have
+// Kernfs dentries represent named pointers to inodes. Kernfs is solely
+// responsible for maintaining and modifying its dentry tree; inode
+// implementations cannot access the tree. Dentries and inodes have
// independent lifetimes and reference counts. A child dentry unconditionally
// holds a reference on its parent directory's dentry. A dentry also holds a
-// reference on the inode it points to. Multiple dentries can point to the same
-// inode (for example, in the case of hardlinks). File descriptors hold a
-// reference to the dentry they're opened on.
+// reference on the inode it points to (although that might not be the only
+// reference on the inode). Due to this, inodes can outlive the dentries that
+// point to them. Multiple dentries can point to the same inode (for example,
+// in the case of hardlinks). File descriptors hold a reference to the dentry
+// they're opened on.
//
// Dentries are guaranteed to exist while holding Filesystem.mu for
// reading. Dropping dentries require holding Filesystem.mu for writing. To
@@ -47,8 +51,8 @@
// kernfs.Dentry.dirMu
// vfs.VirtualFilesystem.mountMu
// vfs.Dentry.mu
-// kernfs.Filesystem.droppedDentriesMu
// (inode implementation locks, if any)
+// kernfs.Filesystem.droppedDentriesMu
package kernfs
import (
@@ -60,7 +64,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
@@ -95,7 +98,7 @@ type Filesystem struct {
// example:
//
// fs.mu.RLock()
- // fs.mu.processDeferredDecRefs()
+ // defer fs.processDeferredDecRefs()
// defer fs.mu.RUnlock()
// ...
// fs.deferDecRef(dentry)
@@ -108,8 +111,7 @@ type Filesystem struct {
// deferDecRef defers dropping a dentry ref until the next call to
// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
-//
-// Precondition: d must not already be pending destruction.
+// This may be called while Filesystem.mu or Dentry.dirMu is locked.
func (fs *Filesystem) deferDecRef(d *Dentry) {
fs.droppedDentriesMu.Lock()
fs.droppedDentries = append(fs.droppedDentries, d)
@@ -118,17 +120,14 @@ func (fs *Filesystem) deferDecRef(d *Dentry) {
// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
// droppedDentries list. See comment on Filesystem.mu.
+//
+// Precondition: Filesystem.mu or Dentry.dirMu must NOT be locked.
func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) {
- fs.mu.Lock()
- fs.processDeferredDecRefsLocked(ctx)
- fs.mu.Unlock()
-}
-
-// Precondition: fs.mu must be held for writing.
-func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) {
fs.droppedDentriesMu.Lock()
for _, d := range fs.droppedDentries {
- d.DecRef(ctx)
+ // Defer the DecRef call so that we are not holding droppedDentriesMu
+ // when DecRef is called.
+ defer d.DecRef(ctx)
}
fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
fs.droppedDentriesMu.Unlock()
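The revised locking contract is easiest to see in a usage sketch (a hypothetical in-package operation, mirroring the example in the Filesystem.mu comment). Deferred calls run in LIFO order, so fs.mu.RUnlock() executes before fs.processDeferredDecRefs(ctx), satisfying the new precondition:

```go
// someOp is a hypothetical operation inside package kernfs demonstrating
// the deferDecRef/processDeferredDecRefs pattern.
func (fs *Filesystem) someOp(ctx context.Context, d *Dentry) {
	fs.mu.RLock()
	defer fs.processDeferredDecRefs(ctx) // runs last, after mu is released
	defer fs.mu.RUnlock()                // runs first

	// ... decide d should be dropped while fs.mu is held ...
	fs.deferDecRef(d) // legal while fs.mu or Dentry.dirMu is locked
}
```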
@@ -157,17 +156,19 @@ const (
//
// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
// named reference to an inode. A dentry generally lives as long as it's part of
-// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
-// to them are removed. Dentries hold a single reference to the inode they point
+// a mounted filesystem tree. Kernfs drops dentries once all references to them
+// are dropped. Dentries hold a single reference to the inode they point
// to, and child dentries hold a reference on their parent.
//
// Must be initialized by Init prior to first use.
//
// +stateify savable
type Dentry struct {
+ vfsd vfs.Dentry
DentryRefs
- vfsd vfs.Dentry
+ // fs is the owning filesystem. fs is immutable.
+ fs *Filesystem
// flags caches useful information about the dentry from the inode. See the
// dflags* consts above. Must be accessed by atomic ops.
@@ -192,8 +193,9 @@ type Dentry struct {
// Precondition: Caller must hold a reference on inode.
//
// Postcondition: Caller's reference on inode is transferred to the dentry.
-func (d *Dentry) Init(inode Inode) {
+func (d *Dentry) Init(fs *Filesystem, inode Inode) {
d.vfsd.Init(d)
+ d.fs = fs
d.inode = inode
ftype := inode.Mode().FileType()
if ftype == linux.ModeDirectory {
@@ -222,14 +224,28 @@ func (d *Dentry) isSymlink() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *Dentry) DecRef(ctx context.Context) {
- // Before the destructor is called, Dentry must be removed from VFS' dentry cache.
+ decRefParent := false
+ d.fs.mu.Lock()
d.DentryRefs.DecRef(func() {
d.inode.DecRef(ctx) // IncRef from Init.
d.inode = nil
if d.parent != nil {
- d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild.
+ // We will DecRef d.parent once all locks are dropped.
+ decRefParent = true
+ d.parent.dirMu.Lock()
+ // Remove d from parent.children. It might already have been
+ // removed due to invalidation.
+ if _, ok := d.parent.children[d.name]; ok {
+ delete(d.parent.children, d.name)
+ d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
+ }
+ d.parent.dirMu.Unlock()
}
})
+ d.fs.mu.Unlock()
+ if decRefParent {
+ d.parent.DecRef(ctx) // IncRef from Dentry.insertChild.
+ }
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -247,26 +263,26 @@ func (d *Dentry) Watches() *vfs.Watches {
// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
func (d *Dentry) OnZeroWatches(context.Context) {}
-// InsertChild inserts child into the vfs dentry cache with the given name under
+// insertChild inserts child into the vfs dentry cache with the given name under
// this dentry. This does not update the directory inode, so calling this on its
// own isn't sufficient to insert a child into a directory.
//
// Precondition: d must represent a directory inode.
-func (d *Dentry) InsertChild(name string, child *Dentry) {
+func (d *Dentry) insertChild(name string, child *Dentry) {
d.dirMu.Lock()
- d.InsertChildLocked(name, child)
+ d.insertChildLocked(name, child)
d.dirMu.Unlock()
}
-// InsertChildLocked is equivalent to InsertChild, with additional
+// insertChildLocked is equivalent to insertChild, with additional
// preconditions.
//
// Preconditions:
// * d must represent a directory inode.
// * d.dirMu must be locked.
-func (d *Dentry) InsertChildLocked(name string, child *Dentry) {
+func (d *Dentry) insertChildLocked(name string, child *Dentry) {
if !d.isDir() {
- panic(fmt.Sprintf("InsertChildLocked called on non-directory Dentry: %+v.", d))
+ panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d))
}
d.IncRef() // DecRef in child's Dentry.destroy.
child.parent = d
@@ -277,36 +293,6 @@ func (d *Dentry) InsertChildLocked(name string, child *Dentry) {
d.children[name] = child
}
-// RemoveChild removes child from the vfs dentry cache. This does not update the
-// directory inode or modify the inode to be unlinked. So calling this on its own
-// isn't sufficient to remove a child from a directory.
-//
-// Precondition: d must represent a directory inode.
-func (d *Dentry) RemoveChild(name string, child *Dentry) error {
- d.dirMu.Lock()
- defer d.dirMu.Unlock()
- return d.RemoveChildLocked(name, child)
-}
-
-// RemoveChildLocked is equivalent to RemoveChild, with additional
-// preconditions.
-//
-// Precondition: d.dirMu must be locked.
-func (d *Dentry) RemoveChildLocked(name string, child *Dentry) error {
- if !d.isDir() {
- panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d))
- }
- c, ok := d.children[name]
- if !ok {
- return syserror.ENOENT
- }
- if c != child {
- panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child))
- }
- delete(d.children, name)
- return nil
-}
-
// Inode returns the dentry's inode.
func (d *Dentry) Inode() Inode {
return d.inode
@@ -348,11 +334,6 @@ type Inode interface {
// a blanket implementation for all non-directory inodes.
inodeDirectory
- // Method for inodes that represent dynamic directories and their
- // children. InodeNoDynamicLookup provides a blanket implementation for all
- // non-dynamic-directory inodes.
- inodeDynamicLookup
-
// Open creates a file description for the filesystem object represented by
// this inode. The returned file description should hold a reference on the
// dentry for its lifetime.
@@ -365,6 +346,14 @@ type Inode interface {
// corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem
// doesn't support statfs(2), this should return ENOSYS.
StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error)
+
+ // Keep indicates whether the dentry created after Inode.Lookup should be
+ // kept in the kernfs dentry tree.
+ Keep() bool
+
+ // Valid should return true if this inode is still valid, or false if it
+ // needs to be resolved again by a call to Lookup.
+ Valid(ctx context.Context) bool
}
type inodeRefs interface {
@@ -397,8 +386,8 @@ type inodeMetadata interface {
// Precondition: All methods in this interface may only be called on directory
// inodes.
type inodeDirectory interface {
- // The New{File,Dir,Node,Symlink} methods below should return a new inode
- // hashed into this inode.
+ // The New{File,Dir,Node,Link,Symlink} methods below should return a new inode
+ // that will be hashed into the dentry tree.
//
// These inode constructors are inode-level operations rather than
// filesystem-level operations to allow client filesystems to mix different
@@ -409,60 +398,54 @@ type inodeDirectory interface {
HasChildren() bool
// NewFile creates a new regular file inode.
- NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error)
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (Inode, error)
// NewDir creates a new directory inode.
- NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error)
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (Inode, error)
// NewLink creates a new hardlink to a specified inode in this
// directory. Implementations should create a new kernfs Dentry pointing to
// target, and update target's link count.
- NewLink(ctx context.Context, name string, target Inode) (*Dentry, error)
+ NewLink(ctx context.Context, name string, target Inode) (Inode, error)
// NewSymlink creates a new symbolic link inode.
- NewSymlink(ctx context.Context, name, target string) (*Dentry, error)
+ NewSymlink(ctx context.Context, name, target string) (Inode, error)
// NewNode creates a new filesystem node for a mknod syscall.
- NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error)
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (Inode, error)
// Unlink removes a child dentry from this directory inode.
- Unlink(ctx context.Context, name string, child *Dentry) error
+ Unlink(ctx context.Context, name string, child Inode) error
// RmDir removes an empty child directory from this directory
// inode. Implementations must update the parent directory's link count,
// if required. Implementations are not responsible for checking that child
// is a directory, or for checking that the directory is empty.
- RmDir(ctx context.Context, name string, child *Dentry) error
+ RmDir(ctx context.Context, name string, child Inode) error
// Rename is called on the source directory containing an inode being
// renamed. child should point to the resolved child in the source
- // directory. If Rename replaces a dentry in the destination directory, it
- // should return the replaced dentry or nil otherwise.
+ // directory.
//
// Precondition: Caller must serialize concurrent calls to Rename.
- Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (replaced *Dentry, err error)
-}
+ Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error
-type inodeDynamicLookup interface {
- // Lookup should return an appropriate dentry if name should resolve to a
- // child of this dynamic directory inode. This gives the directory an
- // opportunity on every lookup to resolve additional entries that aren't
- // hashed into the directory. This is only called when the inode is a
- // directory. If the inode is not a directory, or if the directory only
- // contains a static set of children, the implementer can unconditionally
- // return an appropriate error (ENOTDIR and ENOENT respectively).
+ // Lookup should return an appropriate inode if name should resolve to a
+ // child of this directory inode. This gives the directory an opportunity
+ // on every lookup to resolve additional entries. This is only called when
+ // the inode is a directory.
//
- // The child returned by Lookup will be hashed into the VFS dentry tree. Its
- // lifetime can be controlled by the filesystem implementation with an
- // appropriate implementation of Valid.
+ // The child returned by Lookup will be hashed into the VFS dentry tree,
+ // at least for the duration of the current FS operation.
//
- // Lookup returns the child with an extra reference and the caller owns this
- // reference.
- Lookup(ctx context.Context, name string) (*Dentry, error)
-
- // Valid should return true if this inode is still valid, or needs to
- // be resolved again by a call to Lookup.
- Valid(ctx context.Context) bool
+ // Lookup must return the child with an extra reference whose ownership is
+ // transferred to the dentry that is created to point to that inode. If
+ // Inode.Keep returns false, that new dentry will be dropped at the end of
+ // the current filesystem operation (before returning to the VFS layer) if
+ // no other reference is taken on that dentry. If Inode.Keep returns true,
+ // then the dentry will be cached in the dentry tree until it is Unlink'd
+ // or RmDir'd.
+ Lookup(ctx context.Context, name string) (Inode, error)
// IterDirents is used to iterate over dynamically created entries. It invokes
// cb on each entry in the directory represented by the Inode.
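To tie the Lookup/Keep/Valid contract together, here is a hedged sketch of a dynamic directory inode under the new interface (all names other than the kernfs and syserror identifiers are hypothetical; the real implementations in this change live in fsimpl/proc):

```go
// taskLikeDir is a hypothetical dynamic directory. InodeTemporary makes
// Keep() return false, so dentries created by Lookup are dropped at the
// end of each filesystem operation unless another reference is taken.
type taskLikeDir struct {
	kernfs.InodeTemporary
	kernfs.InodeNotSymlink
	kernfs.InodeAttrs
	kernfs.OrderedChildren
}

// Lookup returns a child inode carrying an extra reference; ownership of
// that reference transfers to the dentry that kernfs creates for it.
func (d *taskLikeDir) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
	child, ok := resolveChild(ctx, name) // hypothetical resolver
	if !ok {
		return nil, syserror.ENOENT
	}
	return child, nil
}

// Valid reports whether an inode found by a previous Lookup may still be
// used, or whether kernfs must resolve the name again.
func (d *taskLikeDir) Valid(ctx context.Context) bool {
	return stillExists(d) // hypothetical liveness check
}
```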
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index e413242dc..82fa19c03 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file."
// RootDentryFn is a generator function for creating the root dentry of a test
// filesystem. See newTestSystem.
-type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
+type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode
// newTestSystem sets up a minimal environment for running a test, including an
// instance of a test filesystem. Tests can control the contents of the
@@ -72,14 +72,11 @@ type file struct {
content string
}
-func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
+func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode {
f := &file{}
f.content = content
f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
-
- d := &kernfs.Dentry{}
- d.Init(f)
- return d
+ return f
}
func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error {
@@ -98,27 +95,23 @@ func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.S
type readonlyDir struct {
readonlyDirRefs
attrs
+ kernfs.InodeAlwaysValid
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNoStatFS
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
locks vfs.FileLocks
-
- dentry kernfs.Dentry
}
-func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &readonlyDir{}
dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
dir.EnableLeakCheck()
- dir.dentry.Init(dir)
-
- dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
-
- return &dir.dentry
+ dir.IncLinks(dir.OrderedChildren.Populate(contents))
+ return dir
}
func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
@@ -131,35 +124,33 @@ func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernf
return fd.VFSFileDescription(), nil
}
-func (d *readonlyDir) DecRef(context.Context) {
- d.readonlyDirRefs.DecRef(d.Destroy)
+func (d *readonlyDir) DecRef(ctx context.Context) {
+ d.readonlyDirRefs.DecRef(func() { d.Destroy(ctx) })
}
type dir struct {
dirRefs
attrs
- kernfs.InodeNoDynamicLookup
+ kernfs.InodeAlwaysValid
kernfs.InodeNotSymlink
- kernfs.OrderedChildren
kernfs.InodeNoStatFS
+ kernfs.InodeTemporary
+ kernfs.OrderedChildren
locks vfs.FileLocks
- fs *filesystem
- dentry kernfs.Dentry
+ fs *filesystem
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &dir{}
dir.fs = fs
dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
dir.EnableLeakCheck()
- dir.dentry.Init(dir)
- dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
-
- return &dir.dentry
+ dir.IncLinks(dir.OrderedChildren.Populate(contents))
+ return dir
}
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
@@ -172,11 +163,11 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry
return fd.VFSFileDescription(), nil
}
-func (d *dir) DecRef(context.Context) {
- d.dirRefs.DecRef(d.Destroy)
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
-func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
dir := d.fs.newDir(creds, opts.Mode, nil)
if err := d.OrderedChildren.Insert(name, dir); err != nil {
@@ -187,7 +178,7 @@ func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*
return dir, nil
}
-func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
f := d.fs.newFile(creds, "")
if err := d.OrderedChildren.Insert(name, f); err != nil {
@@ -197,15 +188,15 @@ func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*
return f, nil
}
-func (*dir) NewLink(context.Context, string, kernfs.Inode) (*kernfs.Dentry, error) {
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
-func (*dir) NewSymlink(context.Context, string, string) (*kernfs.Dentry, error) {
+func (*dir) NewSymlink(context.Context, string, string) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
-func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*kernfs.Dentry, error) {
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (kernfs.Inode, error) {
return nil, syserror.EPERM
}
@@ -213,18 +204,22 @@ func (fsType) Name() string {
return "kernfs"
}
+func (fsType) Release(ctx context.Context) {}
+
func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
fs.VFSFilesystem().Init(vfsObj, &fst, fs)
root := fst.rootFn(creds, fs)
- return fs.VFSFilesystem(), root.VFSDentry(), nil
+ var d kernfs.Dentry
+ d.Init(&fs.Filesystem, root)
+ return fs.VFSFilesystem(), d.VFSDentry(), nil
}
// -------------------- Remainder of the file are test cases --------------------
func TestBasic(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"file1": fs.newFile(creds, staticFileContent),
})
})
@@ -233,8 +228,8 @@ func TestBasic(t *testing.T) {
}
func TestMkdirGetDentry(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"dir1": fs.newDir(creds, 0755, nil),
})
})
@@ -248,8 +243,8 @@ func TestMkdirGetDentry(t *testing.T) {
}
func TestReadStaticFile(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"file1": fs.newFile(creds, staticFileContent),
})
})
@@ -274,8 +269,8 @@ func TestReadStaticFile(t *testing.T) {
}
func TestCreateNewFileInStaticDir(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
"dir1": fs.newDir(creds, 0755, nil),
})
})
@@ -301,7 +296,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
func TestDirFDReadWrite(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
return fs.newReadonlyDir(creds, 0755, nil)
})
defer sys.Destroy()
@@ -325,11 +320,11 @@ func TestDirFDReadWrite(t *testing.T) {
}
func TestDirFDIterDirents(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
- return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
// Fill root with nodes backed by various inode implementations.
"dir1": fs.newReadonlyDir(creds, 0755, nil),
- "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{
"dir3": fs.newDir(creds, 0755, nil),
}),
"file1": fs.newFile(creds, staticFileContent),
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 58a93eaac..934cc6c9e 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -38,13 +38,10 @@ type StaticSymlink struct {
var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) *Dentry {
+func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
inode := &StaticSymlink{}
inode.Init(creds, devMajor, devMinor, ino, target)
-
- d := &Dentry{}
- d.Init(inode)
- return d
+ return inode
}
// Init initializes the instance.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index ea7f073eb..d0ed17b18 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -29,24 +29,22 @@ import (
//
// +stateify savable
type syntheticDirectory struct {
+ InodeAlwaysValid
InodeAttrs
InodeNoStatFS
- InodeNoopRefCount
- InodeNoDynamicLookup
InodeNotSymlink
OrderedChildren
+ syntheticDirectoryRefs
locks vfs.FileLocks
}
var _ Inode = (*syntheticDirectory)(nil)
-func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *Dentry {
+func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode {
inode := &syntheticDirectory{}
inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
- d := &Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
@@ -69,34 +67,46 @@ func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath,
}
// NewFile implements Inode.NewFile.
-func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (Inode, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (Inode, error) {
if !opts.ForSyntheticMountpoint {
return nil, syserror.EPERM
}
- subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
- if err := dir.OrderedChildren.Insert(name, subdird); err != nil {
- subdird.DecRef(ctx)
+ subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ if err := dir.OrderedChildren.Insert(name, subdirI); err != nil {
+ subdirI.DecRef(ctx)
return nil, err
}
- return subdird, nil
+ return subdirI, nil
}
// NewLink implements Inode.NewLink.
-func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) {
+func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (Inode, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*Dentry, error) {
+func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (Inode, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) {
+func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (Inode, error) {
return nil, syserror.EPERM
}
+
+// DecRef implements Inode.DecRef.
+func (dir *syntheticDirectory) DecRef(ctx context.Context) {
+ dir.syntheticDirectoryRefs.DecRef(func() { dir.Destroy(ctx) })
+}
+
+// Keep implements Inode.Keep. This is redundant because these inodes are
+// never created via Lookup and are always valid. It returns true because
+// these inodes are not temporary and should only be removed via RmDir.
+func (dir *syntheticDirectory) Keep() bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index dfbccd05f..e5f506d2e 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -60,6 +60,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
// FilesystemType.GetFilesystem.
//
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 4e2da4810..e44b79b68 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -39,6 +39,9 @@ func (filesystemType) Name() string {
return "pipefs"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
panic("pipefs.filesystemType.GetFilesystem should never be called")
@@ -165,7 +168,7 @@ func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vf
fs := mnt.Filesystem().Impl().(*filesystem)
inode := newInode(ctx, fs)
var d kernfs.Dentry
- d.Init(inode)
+ d.Init(&fs.Filesystem, inode)
defer d.DecRef(ctx)
return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
}
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 05d7948ea..fd70a07de 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -34,13 +34,14 @@ const Name = "proc"
// +stateify savable
type FilesystemType struct{}
-var _ vfs.FilesystemType = (*FilesystemType)(nil)
-
// Name implements vfs.FilesystemType.Name.
func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -73,7 +74,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
cgroups = data.Cgroups
}
- _, dentry := procfs.newTasksInode(k, pidns, cgroups)
+ inode := procfs.newTasksInode(k, pidns, cgroups)
+ var dentry kernfs.Dentry
+ dentry.Init(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
@@ -94,12 +97,9 @@ type dynamicInode interface {
Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
}
-func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
+ return inode
}
// +stateify savable
@@ -114,8 +114,8 @@ func newStaticFile(data string) *staticFile {
return &staticFile{StaticData: vfs.StaticData{Data: data}}
}
-func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
- return kernfs.NewStaticDir(creds, devMajor, devMinor, ino, perm, children, kernfs.GenericDirectoryFDOptions{
+func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
+ return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 47ecd941c..bad2fab4f 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -32,10 +32,11 @@ import (
// +stateify savable
type subtasksInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
subtasksInodeRefs
@@ -49,7 +50,7 @@ type subtasksInode struct {
var _ kernfs.Inode = (*subtasksInode)(nil)
-func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *kernfs.Dentry {
+func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode {
subInode := &subtasksInode{
fs: fs,
task: task,
@@ -62,14 +63,11 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
subInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: subInode, owner: task}
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
-
- return dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
tid, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -82,10 +80,10 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry
if subTask.ThreadGroup() != i.task.ThreadGroup() {
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers), nil
+ return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers)
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
@@ -186,6 +184,6 @@ func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credential
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *subtasksInode) DecRef(context.Context) {
- i.subtasksInodeRefs.DecRef(i.Destroy)
+func (i *subtasksInode) DecRef(ctx context.Context) {
+ i.subtasksInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index a7cd6f57e..b63a4eca0 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -35,8 +35,8 @@ type taskInode struct {
implStatFS
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
taskInodeRefs
@@ -47,41 +47,44 @@ type taskInode struct {
var _ kernfs.Inode = (*taskInode)(nil)
-func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
- // TODO(gvisor.dev/issue/164): Fail with ESRCH if task exited.
- contents := map[string]*kernfs.Dentry{
- "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}),
- "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+
+ contents := map[string]kernfs.Inode{
+ "auxv": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
"comm": fs.newComm(task, fs.NextIno(), 0444),
"cwd": fs.newCwdSymlink(task, fs.NextIno()),
- "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "environ": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
"exe": fs.newExeSymlink(task, fs.NextIno()),
"fd": fs.newFDDirInode(task),
"fdinfo": fs.newFDInfoDirInode(task),
- "gid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
- "io": fs.newTaskOwnedFile(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
- "maps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mapsData{task: task}),
- "mountinfo": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
- "mounts": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountsData{task: task}),
+ "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}),
"net": fs.newTaskNetDir(task),
- "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]kernfs.Inode{
"net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"),
"pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"),
"user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"),
}),
- "oom_score": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newStaticFile("0\n")),
- "oom_score_adj": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
- "smaps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &smapsData{task: task}),
- "stat": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
- "statm": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statmData{task: task}),
- "status": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
- "uid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ "oom_score": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &smapsData{task: task}),
+ "stat": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statmData{task: task}),
+ "status": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers)
}
if len(cgroupControllers) > 0 {
- contents["cgroup"] = fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ contents["cgroup"] = fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
}
taskInode := &taskInode{task: task}
@@ -90,17 +93,15 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := taskInode.OrderedChildren.Populate(dentry, contents)
+ links := taskInode.OrderedChildren.Populate(contents)
taskInode.IncLinks(links)
- return dentry
+ return inode, nil
}
-// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// Valid implements kernfs.Inode.Valid. This inode remains valid as long
// as the task is still running. When it's dead, another task with the same
// PID could replace it.
func (i *taskInode) Valid(ctx context.Context) bool {
@@ -124,8 +125,8 @@ func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, v
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *taskInode) DecRef(context.Context) {
- i.taskInodeRefs.DecRef(i.Destroy)
+func (i *taskInode) DecRef(ctx context.Context) {
+ i.taskInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
@@ -141,34 +142,23 @@ type taskOwnedInode struct {
var _ kernfs.Inode = (*taskOwnedInode)(nil)
-func (fs *filesystem) newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
- taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
+ return &taskOwnedInode{Inode: inode, owner: task}
}
-func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
- dir := &kernfs.StaticDirectory{}
-
+func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
- dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, kernfs.GenericDirectoryFDOptions{
- SeekEnd: kernfs.SeekEndZero,
- })
- dir.EnableLeakCheck()
-
- inode := &taskOwnedInode{Inode: dir, owner: task}
- d := &kernfs.Dentry{}
- d.Init(inode)
+ fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero}
+ dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
- dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := dir.OrderedChildren.Populate(d, children)
- dir.IncLinks(links)
+ return &taskOwnedInode{Inode: dir, owner: task}
+}
- return d
+func (i *taskOwnedInode) Valid(ctx context.Context) bool {
+ return i.owner.ExitState() != kernel.TaskExitDead && i.Inode.Valid(ctx)
}
// Stat implements kernfs.Inode.Stat.
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 0866cea2b..2c80ac5c2 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -63,7 +63,7 @@ type fdDir struct {
produceSymlink bool
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
@@ -109,16 +109,17 @@ type fdDirInode struct {
fdDir
fdDirInodeRefs
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
}
var _ kernfs.Inode = (*fdDirInode)(nil)
-func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
inode := &fdDirInode{
fdDir: fdDir{
fs: fs,
@@ -128,16 +129,17 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
}
inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
-
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ return inode
+}
- return dentry
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
+func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *fdDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -183,8 +185,8 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *fdDirInode) DecRef(context.Context) {
- i.fdDirInodeRefs.DecRef(i.Destroy)
+func (i *fdDirInode) DecRef(ctx context.Context) {
+ i.fdDirInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file.
@@ -202,16 +204,13 @@ type fdSymlink struct {
var _ kernfs.Inode = (*fdSymlink)(nil)
-func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kernfs.Inode {
inode := &fdSymlink{
task: task,
fd: fd,
}
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
@@ -236,6 +235,11 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen
return vd, "", nil
}
+// Valid implements kernfs.Inode.Valid.
+func (s *fdSymlink) Valid(ctx context.Context) bool {
+ return taskFDExists(ctx, s.task, s.fd)
+}
+
// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
//
// +stateify savable
@@ -243,16 +247,17 @@ type fdInfoDirInode struct {
fdDir
fdInfoDirInodeRefs
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary
kernfs.OrderedChildren
}
var _ kernfs.Inode = (*fdInfoDirInode)(nil)
-func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
inode := &fdInfoDirInode{
fdDir: fdDir{
fs: fs,
@@ -261,16 +266,12 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
}
inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
-
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
-
- return dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -283,7 +284,12 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentr
task: i.task,
fd: fd,
}
- return i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data), nil
+ return i.fs.newTaskOwnedInode(i.task, i.fs.NextIno(), 0444, data), nil
+}
+
+// IterDirents implements Inode.IterDirents.
+func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
}
// Open implements kernfs.Inode.Open.
@@ -298,8 +304,8 @@ func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *ker
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *fdInfoDirInode) DecRef(context.Context) {
- i.fdInfoDirInodeRefs.DecRef(i.Destroy)
+func (i *fdInfoDirInode) DecRef(ctx context.Context) {
+ i.fdInfoDirInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
@@ -328,3 +334,8 @@ func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "flags:\t0%o\n", flags)
return nil
}
+
+// Valid implements kernfs.Inode.Valid.
+func (d *fdInfoData) Valid(ctx context.Context) bool {
+ return taskFDExists(ctx, d.task, d.fd)
+}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 3fbf081a6..79f8b7e9f 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -247,13 +247,10 @@ type commInode struct {
task *kernel.Task
}
-func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
inode := &commInode{task: task}
inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
@@ -658,13 +655,10 @@ type exeSymlink struct {
var _ kernfs.Inode = (*exeSymlink)(nil)
-func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &exeSymlink{task: task}
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
// Readlink implements kernfs.Inode.Readlink.
@@ -737,13 +731,10 @@ type cwdSymlink struct {
var _ kernfs.Inode = (*cwdSymlink)(nil)
-func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &cwdSymlink{task: task}
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+ return inode
}
// Readlink implements kernfs.Inode.Readlink.
@@ -851,7 +842,7 @@ type namespaceSymlink struct {
task *kernel.Task
}
-func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) kernfs.Inode {
// Namespace symlinks should contain the namespace name and the inode number
// for the namespace instance, so for example user:[123456]. We currently fake
// the inode number by sticking the symlink inode in its place.
@@ -862,9 +853,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
+ return taskInode
}
// Readlink implements kernfs.Inode.Readlink.
@@ -882,11 +871,12 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
}
// Create a synthetic inode to represent the namespace.
+ fs := mnt.Filesystem().Impl().(*filesystem)
dentry := &kernfs.Dentry{}
- dentry.Init(&namespaceInode{})
+ dentry.Init(&fs.Filesystem, &namespaceInode{})
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
- vd.IncRef()
- dentry.DecRef(ctx)
+ // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1.
+ mnt.IncRef()
return vd, "", nil
}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index e7f748655..3425e8698 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -37,12 +37,12 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
+func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode {
k := task.Kernel()
pidns := task.PIDNamespace()
root := auth.NewRootCredentials(pidns.UserNamespace())
- var contents map[string]*kernfs.Dentry
+ var contents map[string]kernfs.Inode
if stack := task.NetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
@@ -56,34 +56,34 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
// TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
// network namespace.
- contents = map[string]*kernfs.Dentry{
- "dev": fs.newDentry(root, fs.NextIno(), 0444, &netDevData{stack: stack}),
- "snmp": fs.newDentry(root, fs.NextIno(), 0444, &netSnmpData{stack: stack}),
+ contents = map[string]kernfs.Inode{
+ "dev": fs.newInode(root, 0444, &netDevData{stack: stack}),
+ "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, if the file contains a header the stub is just the header
// otherwise it is an empty file.
- "arp": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(arp)),
- "netlink": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(netlink)),
- "netstat": fs.newDentry(root, fs.NextIno(), 0444, &netStatData{}),
- "packet": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(packet)),
- "protocols": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(protocols)),
+ "arp": fs.newInode(root, 0444, newStaticFile(arp)),
+ "netlink": fs.newInode(root, 0444, newStaticFile(netlink)),
+ "netstat": fs.newInode(root, 0444, &netStatData{}),
+ "packet": fs.newInode(root, 0444, newStaticFile(packet)),
+ "protocols": fs.newInode(root, 0444, newStaticFile(protocols)),
// Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
// high res timer ticks per sec (ClockGetres returns 1ns resolution).
- "psched": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(psched)),
- "ptype": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(ptype)),
- "route": fs.newDentry(root, fs.NextIno(), 0444, &netRouteData{stack: stack}),
- "tcp": fs.newDentry(root, fs.NextIno(), 0444, &netTCPData{kernel: k}),
- "udp": fs.newDentry(root, fs.NextIno(), 0444, &netUDPData{kernel: k}),
- "unix": fs.newDentry(root, fs.NextIno(), 0444, &netUnixData{kernel: k}),
+ "psched": fs.newInode(root, 0444, newStaticFile(psched)),
+ "ptype": fs.newInode(root, 0444, newStaticFile(ptype)),
+ "route": fs.newInode(root, 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}),
+ "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}),
+ "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}),
}
if stack.SupportsIPv6() {
- contents["if_inet6"] = fs.newDentry(root, fs.NextIno(), 0444, &ifinet6{stack: stack})
- contents["ipv6_route"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(""))
- contents["tcp6"] = fs.newDentry(root, fs.NextIno(), 0444, &netTCP6Data{kernel: k})
- contents["udp6"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(upd6))
+ contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6))
}
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index d8f5dd509..3259c3732 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -38,10 +38,11 @@ const (
// +stateify savable
type tasksInode struct {
implStatFS
- kernfs.AlwaysValid
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This has no effect because this inode is never looked up and is always valid.
kernfs.OrderedChildren
tasksInodeRefs
@@ -52,8 +53,6 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- selfSymlink *kernfs.Dentry
- threadSelfSymlink *kernfs.Dentry
// cgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
@@ -63,52 +62,53 @@ type tasksInode struct {
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
+func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
- contents := map[string]*kernfs.Dentry{
- "cpuinfo": fs.newDentry(root, fs.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
- "filesystems": fs.newDentry(root, fs.NextIno(), 0444, &filesystemsData{}),
- "loadavg": fs.newDentry(root, fs.NextIno(), 0444, &loadavgData{}),
+ contents := map[string]kernfs.Inode{
+ "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newInode(root, 0444, &filesystemsData{}),
+ "loadavg": fs.newInode(root, 0444, &loadavgData{}),
"sys": fs.newSysDir(root, k),
- "meminfo": fs.newDentry(root, fs.NextIno(), 0444, &meminfoData{}),
+ "meminfo": fs.newInode(root, 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
"net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
- "stat": fs.newDentry(root, fs.NextIno(), 0444, &statData{}),
- "uptime": fs.newDentry(root, fs.NextIno(), 0444, &uptimeData{}),
- "version": fs.newDentry(root, fs.NextIno(), 0444, &versionData{}),
+ "stat": fs.newInode(root, 0444, &statData{}),
+ "uptime": fs.newInode(root, 0444, &uptimeData{}),
+ "version": fs.newInode(root, 0444, &versionData{}),
}
inode := &tasksInode{
pidns: pidns,
fs: fs,
- selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns),
- threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns),
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
- dentry := &kernfs.Dentry{}
- dentry.Init(inode)
-
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- links := inode.OrderedChildren.Populate(dentry, contents)
+ links := inode.OrderedChildren.Populate(contents)
inode.IncLinks(links)
- return inode, dentry
+ return inode
}
-// Lookup implements kernfs.inodeDynamicLookup.Lookup.
-func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
- // Try to lookup a corresponding task.
+// Lookup implements kernfs.inodeDirectory.Lookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
+ // Check if a static entry was looked up.
+ if d, err := i.OrderedChildren.Lookup(ctx, name); err == nil {
+ return d, nil
+ }
+
+ // Not a static entry. Try to lookup a corresponding task.
tid, err := strconv.ParseUint(name, 10, 64)
if err != nil {
+ root := auth.NewRootCredentials(i.pidns.UserNamespace())
// If it failed to parse, check if it's one of the special handled files.
switch name {
case selfName:
- return i.selfSymlink, nil
+ return i.newSelfSymlink(root), nil
case threadSelfName:
- return i.threadSelfSymlink, nil
+ return i.newThreadSelfSymlink(root), nil
}
return nil, syserror.ENOENT
}
@@ -118,10 +118,10 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, e
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers), nil
+ return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers)
}
-// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
+// IterDirents implements kernfs.inodeDirectory.IterDirents.
func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
@@ -229,8 +229,8 @@ func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.St
}
// DecRef implements kernfs.Inode.DecRef.
-func (i *tasksInode) DecRef(context.Context) {
- i.tasksInodeRefs.DecRef(i.Destroy)
+func (i *tasksInode) DecRef(ctx context.Context) {
+ i.tasksInodeRefs.DecRef(func() { i.Destroy(ctx) })
}
// staticFileSetStat implements a special static file that allows inode
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index f268c59b0..07c27cdd9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -43,13 +43,10 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
- inode := &selfSymlink{pidns: pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+ inode := &selfSymlink{pidns: i.pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ return inode
}
func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
@@ -87,13 +84,10 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
- inode := &threadSelfSymlink{pidns: pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
-
- d := &kernfs.Dentry{}
- d.Init(inode)
- return d
+func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+ inode := &threadSelfSymlink{pidns: i.pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ return inode
}
func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
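Since tasksInode no longer caches these symlinks, each Lookup of "self" or "thread-self" mints a fresh inode. A toy model of create-on-lookup (hypothetical names, not the proc types):

package main

import "fmt"

type symlink struct{ target string }

// lookup creates special entries on demand instead of returning a
// cached instance; each call yields an independent inode.
func lookup(name string, pid int) (*symlink, error) {
	switch name {
	case "self":
		return &symlink{target: fmt.Sprint(pid)}, nil
	case "thread-self":
		return &symlink{target: fmt.Sprintf("%d/task/%d", pid, pid)}, nil
	}
	return nil, fmt.Errorf("ENOENT: %s", name)
}

func main() {
	s, _ := lookup("self", 42)
	fmt.Println("self ->", s.target)
}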
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 3312b0418..95420368d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -40,93 +40,93 @@ const (
)
// newSysDir returns the inode corresponding to the /proc/sys directory.
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
- return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "kernel": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}),
- "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)),
+func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ return fs.newStaticDir(root, map[string]kernfs.Inode{
+ "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "hostname": fs.newInode(root, 0444, &hostnameData{}),
+ "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)),
}),
- "vm": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}),
- "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "vm": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")),
}),
"net": fs.newSysNetDir(root, k),
})
}
// newSysNetDir returns the inode corresponding to the /proc/sys/net directory.
-func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
- var contents map[string]*kernfs.Dentry
+func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ var contents map[string]kernfs.Inode
// TODO(gvisor.dev/issue/1833): Support for using the network stack in the
// network namespace of the calling process.
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
- contents = map[string]*kernfs.Dentry{
- "ipv4": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}),
+ contents = map[string]kernfs.Inode{
+ "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack; most of these files are configuration related. We use the
// value closest to the actual netstack behavior or an empty file; all
// of these files have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("16000 65535")),
- "ip_local_reserved_ports": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "ipfrag_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("30")),
- "ip_nonlocal_bind": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "ip_no_pmtu_disc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")),
+ "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")),
// tcp_allowed_congestion_control tells the user what they are able to
// do as an unprivileged process, so we leave it empty.
- "tcp_allowed_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "tcp_available_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
- "tcp_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
// Many of the following stub files are features netstack doesn't
// support. The unsupported features return "0" to indicate they are
// disabled.
- "tcp_base_mss": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1280")),
- "tcp_dsack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_early_retrans": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fastopen": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_fastopen_key": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
- "tcp_invalid_ratelimit": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_intvl": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_probes": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_keepalive_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("7200")),
- "tcp_mtu_probing": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_no_metrics_save": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_probe_interval": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_probe_threshold": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "tcp_retries1": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
- "tcp_retries2": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("15")),
- "tcp_rfc1337": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_slow_start_after_idle": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
- "tcp_synack_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
- "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
- "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")),
}),
- "core": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")),
- "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")),
- "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
- "optmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
- "rmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "rmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "somaxconn": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("128")),
- "wmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
- "wmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "core": fs.newStaticDir(root, map[string]kernfs.Inode{
+ "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newInode(root, 0444, newStaticFile("10")),
+ "message_cost": fs.newInode(root, 0444, newStaticFile("5")),
+ "optmem_max": fs.newInode(root, 0444, newStaticFile("0")),
+ "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
+ "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
+ "somaxconn": fs.newInode(root, 0444, newStaticFile("128")),
+ "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
+ "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
}),
}
}
- return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents)
+ return fs.newStaticDir(root, contents)
}
// mmapMinAddrData implements vfs.DynamicBytesSource for
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 6975af5a7..2582ababd 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -109,9 +109,12 @@ func setup(t *testing.T) *testutil.System {
if err != nil {
t.Fatalf("NewMountNamespace(): %v", err)
}
+ root := mntns.Root()
+ root.IncRef()
+ defer root.DecRef(ctx)
pop := &vfs.PathOperation{
- Root: mntns.Root(),
- Start: mntns.Root(),
+ Root: root,
+ Start: root,
Path: fspath.Parse("/proc"),
}
if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
@@ -119,8 +122,8 @@ func setup(t *testing.T) *testutil.System {
}
pop = &vfs.PathOperation{
- Root: mntns.Root(),
- Start: mntns.Root(),
+ Root: root,
+ Start: root,
Path: fspath.Parse("/proc"),
}
mntOpts := &vfs.MountOptions{
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
index 067c1657f..adb610213 100644
--- a/pkg/sentry/fsimpl/signalfd/BUILD
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/kernel",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index bf11b425a..10f1452ef 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -16,7 +16,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -95,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -105,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
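The switch from binary.Marshal to the generated WriteTo drops the fixed 128-byte staging buffer and lets a partial copy-out succeed. A simplified stand-in (using encoding/binary rather than gVisor's go_marshal, with a hypothetical siginfo subset):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// siginfo is a hypothetical fixed-layout subset of linux.SignalfdSiginfo.
type siginfo struct {
	Signo uint32
	Errno int32
	Code  int32
	PID   uint32
}

func main() {
	var buf bytes.Buffer
	info := siginfo{Signo: 17, Code: 1, PID: 42}
	// binary.Write serializes the struct field-by-field, writing
	// directly to the destination writer with no staging copy.
	if err := binary.Write(&buf, binary.LittleEndian, &info); err != nil {
		fmt.Println("write failed:", err)
		return
	}
	fmt.Printf("wrote %d bytes\n", buf.Len())
}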
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index 29e5371d6..cf91ea36c 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -46,6 +46,9 @@ func (filesystemType) Name() string {
return "sockfs"
}
+// Release implements vfs.FilesystemType.Release.
+func (filesystemType) Release(ctx context.Context) {}
+
// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -114,6 +117,6 @@ func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
d := &kernfs.Dentry{}
- d.Init(i)
+ d.Init(&fs.Filesystem, i)
return d.VFSDentry()
}
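sockfs and the sys filesystem below now follow the same shape: constructors build inodes, and whoever needs a dentry binds one explicitly with the filesystem handle. A toy model of the two-argument Init (not the kernfs types):

package main

import "fmt"

type inode struct{ name string }

type filesystem struct{ name string }

// dentry binds an inode to its filesystem; the two-argument Init lets
// the dentry reach filesystem-wide state without a back-pointer on the
// inode itself.
type dentry struct {
	fs    *filesystem
	inode *inode
}

func (d *dentry) Init(fs *filesystem, i *inode) {
	d.fs = fs
	d.inode = i
}

func main() {
	fs := &filesystem{name: "sockfs"}
	root := &inode{name: "/"}
	var d dentry
	d.Init(fs, root)
	fmt.Printf("%s dentry for %q\n", d.fs.name, d.inode.name)
}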
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index b75d70ae6..94366d429 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -27,12 +27,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) *kernfs.Dentry {
+func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
k := &kcovInode{}
k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
- d := &kernfs.Dentry{}
- d.Init(k)
- return d
+ return k
}
// kcovInode implements kernfs.Inode.
@@ -104,7 +102,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
func (fd *kcovFD) Release(ctx context.Context) {
// kcov instances have reference counts in Linux, but this seems sufficient
// for our purposes.
- fd.kcov.Reset()
+ fd.kcov.Clear()
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1568c581f..1ad679830 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -52,6 +52,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
@@ -64,15 +67,15 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
- root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"block": fs.newDir(creds, defaultSysDirMode, nil),
"bus": fs.newDir(creds, defaultSysDirMode, nil),
- "class": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"power_supply": fs.newDir(creds, defaultSysDirMode, nil),
}),
"dev": fs.newDir(creds, defaultSysDirMode, nil),
- "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
- "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
"cpu": cpuDir(ctx, fs, creds),
}),
}),
@@ -82,13 +85,15 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
"module": fs.newDir(creds, defaultSysDirMode, nil),
"power": fs.newDir(creds, defaultSysDirMode, nil),
})
- return fs.VFSFilesystem(), root.VFSDentry(), nil
+ var rootD kernfs.Dentry
+ rootD.Init(&fs.Filesystem, root)
+ return fs.VFSFilesystem(), rootD.VFSDentry(), nil
}
-func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry {
+func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
- children := map[string]*kernfs.Dentry{
+ children := map[string]kernfs.Inode{
"online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
"possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
"present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
@@ -99,14 +104,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernf
return fs.newDir(creds, defaultSysDirMode, children)
}
-func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry {
+func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
// If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs
// should be mounted at debug/, but for our purposes, it is sufficient to
// keep it in sys.
- var children map[string]*kernfs.Dentry
+ var children map[string]kernfs.Inode
if coverage.KcovAvailable() {
- children = map[string]*kernfs.Dentry{
- "debug": fs.newDir(creds, linux.FileMode(0700), map[string]*kernfs.Dentry{
+ children = map[string]kernfs.Inode{
+ "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{
"kcov": fs.newKcovFile(ctx, creds),
}),
}
@@ -125,27 +130,23 @@ func (fs *filesystem) Release(ctx context.Context) {
// +stateify savable
type dir struct {
dirRefs
+ kernfs.InodeAlwaysValid
kernfs.InodeAttrs
- kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeTemporary
kernfs.OrderedChildren
locks vfs.FileLocks
-
- dentry kernfs.Dentry
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
d := &dir{}
d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
d.EnableLeakCheck()
- d.dentry.Init(d)
-
- d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents))
-
- return &d.dentry
+ d.IncLinks(d.OrderedChildren.Populate(contents))
+ return d
}
// SetStat implements kernfs.Inode.SetStat, disallowing changes to inode attributes.
@@ -165,8 +166,8 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry
}
// DecRef implements kernfs.Inode.DecRef.
-func (d *dir) DecRef(context.Context) {
- d.dirRefs.DecRef(d.Destroy)
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
// StatFS implements kernfs.Inode.StatFS.
@@ -190,12 +191,10 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
c := &cpuFile{maxCores: maxCores}
c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
- d := &kernfs.Dentry{}
- d.Init(c)
- return d
+ return c
}
// +stateify savable
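The dir type above swaps mixins rather than redefining methods: embedding small structs like InodeAlwaysValid fills interface obligations with canned behavior. A toy of the pattern (hypothetical names, not the kernfs interfaces):

package main

import "fmt"

type validator interface{ Valid() bool }

// alwaysValid is a mixin: embedding it satisfies validator with a
// canned answer, the way kernfs mixins fill in Inode methods.
type alwaysValid struct{}

func (alwaysValid) Valid() bool { return true }

type dir struct {
	alwaysValid
	name string
}

func main() {
	var v validator = dir{name: "sys"}
	fmt.Println(v.Valid()) // true, via the mixin
}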
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 568132121..1a8525b06 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -46,16 +46,18 @@ type System struct {
// NewSystem constructs a System.
//
-// Precondition: Caller must hold a reference on MntNs, whose ownership
+// Precondition: Caller must hold a reference on mns, whose ownership
// is transferred to the new System.
func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
+ root := mns.Root()
+ root.IncRef()
s := &System{
t: t,
Ctx: ctx,
Creds: auth.CredentialsFromContext(ctx),
VFS: v,
MntNs: mns,
- Root: mns.Root(),
+ Root: root,
}
return s
}
@@ -254,10 +256,10 @@ func (d *DirentCollector) Contains(name string, typ uint8) error {
defer d.mu.Unlock()
dirent, ok := d.dirents[name]
if !ok {
- return fmt.Errorf("No dirent named %q found", name)
+ return fmt.Errorf("no dirent named %q found", name)
}
if dirent.Type != typ {
- return fmt.Errorf("Dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
+ return fmt.Errorf("dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
}
return nil
}
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 5209a17af..3cc63e732 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -193,6 +193,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
// Create nested directories with given depth.
root := mntns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
vd := root
vd.IncRef()
@@ -387,6 +388,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
// Create the mount point.
root := mntns.Root()
+ root.IncRef()
defer root.DecRef(ctx)
pop := vfs.PathOperation{
Root: root,
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index be29a2363..2f856ce36 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -165,6 +165,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create the pipe.
root := mntns.Root()
+ root.IncRef()
pop := vfs.PathOperation{
Root: root,
Start: root,
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index a199eb33d..ce4e3eda7 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -293,7 +293,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.
optional.End = pgend
}
- cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ cerr := rf.data.Fill(ctx, required, optional, rf.size, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
// Newly-allocated pages are zeroed, so we don't need to do anything.
return dsts.NumBytes(), nil
})
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index cefec8fde..e2a0aac69 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -74,6 +74,8 @@ type filesystem struct {
mu sync.RWMutex `state:"nosave"`
nextInoMinusOne uint64 // accessed using atomic memory operations
+
+ root *dentry
}
// Name implements vfs.FilesystemType.Name.
@@ -81,6 +83,9 @@ func (FilesystemType) Name() string {
return Name
}
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// FilesystemOpts is used to pass configuration data to tmpfs.
//
// +stateify savable
@@ -194,6 +199,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.vfsfs.DecRef(ctx)
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
}
+ fs.root = root
return &fs.vfsfs, &root.vfsd, nil
}
@@ -205,6 +211,37 @@ func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *au
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.mu.Lock()
+ if fs.root.inode.isDir() {
+ fs.root.releaseChildrenLocked(ctx)
+ }
+ fs.mu.Unlock()
+}
+
+// releaseChildrenLocked is called on the mount point by filesystem.Release() to
+// destroy all objects in the mount. It performs a depth-first walk of the
+// filesystem and "unlinks" everything by decrementing link counts
+// appropriately. There should be no open file descriptors when this is called,
+// so each inode should only have one outstanding reference that is removed once
+// its link count hits zero.
+//
+// Note that we do not update filesystem state precisely while tearing down
+// (for instance, the child maps are ignored); we only care about removing all
+// remaining references so that every filesystem object gets destroyed. Also
+// note that we do not need to trigger DecRef on the mount point itself or any
+// child mount;
+// these are taken care of by the destructor of the enclosing MountNamespace.
+//
+// Precondition: filesystem.mu is held.
+func (d *dentry) releaseChildrenLocked(ctx context.Context) {
+ dir := d.inode.impl.(*directory)
+ for _, child := range dir.childMap {
+ if child.inode.isDir() {
+ child.releaseChildrenLocked(ctx)
+ child.inode.decLinksLocked(ctx) // link for child/.
+ dir.inode.decLinksLocked(ctx) // link for child/..
+ }
+ child.inode.decLinksLocked(ctx) // link for child
+ }
}
// immutable
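The teardown walk above is easiest to see with concrete link counts: a directory starts at 2 ("." plus its entry in the parent), gains one per subdirectory, and the walk reverses each contribution. A toy model (not the tmpfs types):

package main

import "fmt"

// node is a toy stand-in for a tmpfs dentry/inode pair; only link
// accounting is modeled.
type node struct {
	name     string
	links    int
	children map[string]*node
}

// A directory starts with 2 links: "." and its entry in the parent.
func newDir(name string) *node  { return &node{name: name, links: 2, children: map[string]*node{}} }
func newFile(name string) *node { return &node{name: name, links: 1} }

func (d *node) add(c *node) {
	d.children[c.name] = c
	if c.children != nil {
		d.links++ // child's ".." points back at d
	}
}

// release mirrors releaseChildrenLocked: depth-first, drop the links
// for child/., child/.. and the child's own entry.
func (d *node) release() {
	for _, c := range d.children {
		if c.children != nil {
			c.release()
			c.links-- // link for child/.
			d.links-- // link for child/..
		}
		c.links-- // link for child
		fmt.Printf("%s links now %d\n", c.name, c.links)
	}
}

func main() {
	root := newDir("/")
	sub := newDir("sub")
	root.add(sub)
	sub.add(newFile("f"))
	root.release() // every node ends at zero links
}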
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
index 99c8e3c0f..fc5323abc 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -46,6 +46,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
}
root := mntns.Root()
+ root.IncRef()
return vfsObj, root, func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
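The repeated root.IncRef() additions in these tests track a contract change: MountNamespace.Root() now returns the root without taking a reference on the caller's behalf, so each caller pairs an explicit IncRef with its eventual DecRef. A toy sketch of the convention:

package main

import "fmt"

// handle is a toy refcounted object standing in for vfs.VirtualDentry.
type handle struct{ refs int }

func (h *handle) IncRef() { h.refs++ }
func (h *handle) DecRef() {
	h.refs--
	if h.refs == 0 {
		fmt.Println("released")
	}
}

// root returns the handle without touching its refcount; the caller
// must IncRef before use and DecRef when done.
func root(h *handle) *handle { return h }

func main() {
	h := &handle{refs: 1} // reference held by the namespace itself
	r := root(h)
	r.IncRef()
	defer r.DecRef()
	fmt.Println("using root, refs =", h.refs)
}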
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index bc8e38431..0ca750281 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
licenses(["notice"])
@@ -26,3 +26,22 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "verity_test",
+ srcs = [
+ "verity_test.go",
+ ],
+ library = ":verity",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 26b117ca4..3b3c8725f 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -20,6 +20,7 @@ import (
"io"
"strconv"
"strings"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -155,11 +156,10 @@ afterSymlink:
return child, nil
}
-// verifyChild verifies the root hash of child against the already verified
-// root hash of the parent to ensure the child is expected. verifyChild
-// triggers a sentry panic if unexpected modifications to the file system are
-// detected. In noCrashOnVerificationFailure mode it returns a syserror
-// instead.
+// verifyChild verifies the hash of child against the already verified hash of
+// the parent to ensure the child is expected. verifyChild triggers a sentry
+// panic if unexpected modifications to the file system are detected. In
+// noCrashOnVerificationFailure mode it returns a syserror instead.
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
// TODO(b/166474175): Investigate all possible errors returned in this
// function, and make sure we differentiate all errors that indicate unexpected
@@ -174,12 +174,12 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, err
}
- verityMu.RLock()
- defer verityMu.RUnlock()
+ fs.verityMu.RLock()
+ defer fs.verityMu.RUnlock()
// Read the offset of the child from the extended attributes of the
// corresponding Merkle tree file.
- // This is the offset of the root hash for child in its parent's Merkle
- // tree file.
+ // This is the offset of the hash for child in its parent's Merkle tree
+ // file.
off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{
Root: child.lowerMerkleVD,
Start: child.lowerMerkleVD,
@@ -204,7 +204,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
- // Open parent Merkle tree file to read and verify child's root hash.
+ // Open parent Merkle tree file to read and verify child's hash.
parentMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: parent.lowerMerkleVD,
Start: parent.lowerMerkleVD,
@@ -223,7 +223,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// dataSize is the size of raw data for the Merkle tree. For a file,
// dataSize is the size of the whole file. For a directory, dataSize is
- // the size of all its children's root hashes.
+ // the size of all its children's hashes.
dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
Size: sizeOfStringInt32,
@@ -251,32 +251,152 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Ctx: ctx,
}
+ parentStat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parent.lowerVD,
+ Start: parent.lowerVD,
+ }, &vfs.StatOptions{})
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return nil, err
+ }
+
// Since we are verifying against a directory Merkle tree, buf should
- // contain the root hash of the children in the parent Merkle tree when
+ // contain the hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
+ if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ ReadOffset: int64(offset),
+ ReadSize: int64(merkletree.DigestSize()),
+ Expected: parent.hash,
+ DataAndTreeInSameFile: true,
+ }); err != nil && err != io.EOF {
return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
- // Cache child root hash when it's verified the first time.
- if len(child.rootHash) == 0 {
- child.rootHash = buf.Bytes()
+ // Cache child hash when it's verified the first time.
+ if len(child.hash) == 0 {
+ child.hash = buf.Bytes()
}
return child, nil
}
+// verifyStat verifies the stat against the verified hash. The mode/uid/gid of
+// the file are cached once verified.
+func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+
+ // Get the path to the child dentry. This is only used to provide path
+ // information in failure case.
+ childPath, err := vfsObj.PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD)
+ if err != nil {
+ return err
+ }
+
+ fs.verityMu.RLock()
+ defer fs.verityMu.RUnlock()
+
+ fd, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err == syserror.ENOENT {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ if err == syserror.ENODATA {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ size, err := strconv.Atoi(merkleSize)
+ if err != nil {
+ return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ }
+
+ fdReader := vfs.FileReadWriteSeeker{
+ FD: fd,
+ Ctx: ctx,
+ }
+
+ var buf bytes.Buffer
+ params := &merkletree.VerifyParams{
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ ReadOffset: 0,
+ // Set read size to 0 so only the metadata is verified.
+ ReadSize: 0,
+ Expected: d.hash,
+ DataAndTreeInSameFile: false,
+ }
+ if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
+ params.DataAndTreeInSameFile = true
+ }
+
+ if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
+ return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ }
+ d.mode = uint32(stat.Mode)
+ d.uid = stat.UID
+ d.gid = stat.GID
+ return nil
+}
+
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
// If enabling verification on files/directories is not allowed
// during runtime, all cached children are already verified. If
// runtime enable is allowed and the parent directory is
- // enabled, we should verify the child root hash here because
- // it may be cached before enabled.
- if fs.allowRuntimeEnable && len(parent.rootHash) != 0 {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
- return nil, err
+ // enabled, we should verify the child hash here because it may
+ // be cached before enabled.
+ if fs.allowRuntimeEnable {
+ if isEnabled(parent) {
+ if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ return nil, err
+ }
+ }
+ if isEnabled(child) {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.lowerVD,
+ Start: child.lowerVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ return nil, err
+ }
}
}
return child, nil
@@ -360,9 +480,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// file open and ready to use.
// This may cause empty and unused Merkle tree files in
// allowRuntimeEnable mode, if they are never enabled. This
- // does not affect verification, as we rely on cached root hash
- // to decide whether to perform verification, not the existence
- // of the Merkle tree file. Also, those Merkle tree files are
+	// does not affect verification, as we rely on the cached hash to
+ // decide whether to perform verification, not the existence of
+ // the Merkle tree file. Also, those Merkle tree files are
// always hidden and cannot be accessed by verity fs users.
if fs.allowRuntimeEnable {
childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
@@ -426,20 +546,25 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
child.parent = parent
child.name = name
- // TODO(b/162788573): Verify child metadata.
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
- // Verify child root hash. This should always be performed unless in
+ // Verify child hash. This should always be performed unless in
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
- if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) {
+ if isEnabled(parent) {
if _, err := fs.verifyChild(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
+ if isEnabled(child) {
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ child.destroyLocked(ctx)
+ return nil, err
+ }
+ }
return child, nil
}
@@ -693,22 +818,24 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// be called if a verity FD is created successfully.
defer merkleWriter.DecRef(ctx)
- parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.parent.lowerMerkleVD,
- Start: d.parent.lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_WRONLY | linux.O_APPEND,
- })
- if err != nil {
- if err == syserror.ENOENT {
- parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ if d.parent != nil {
+ parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.parent.lowerMerkleVD,
+ Start: d.parent.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ }
+ return nil, err
}
- return nil, err
+ // parentMerkleWriter is cleaned up if any error occurs. IncRef
+ // will be called if a verity FD is created successfully.
+ defer parentMerkleWriter.DecRef(ctx)
}
- // parentMerkleWriter is cleaned up if any error occurs. IncRef
- // will be called if a verity FD is created successfully.
- defer parentMerkleWriter.DecRef(ctx)
}
fd := &fileDescription{
@@ -769,6 +896,8 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
// StatAt implements vfs.FilesystemImpl.StatAt.
+// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should
+// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
@@ -786,6 +915,11 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(d) {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 3129f290d..70034280b 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -49,12 +49,12 @@ const Name = "verity"
const merklePrefix = ".merkle.verity."
// merkleoffsetInParentXattr is the extended attribute name specifying the
-// offset of child root hash in its parent's Merkle tree.
+// offset of child hash in its parent's Merkle tree.
const merkleOffsetInParentXattr = "user.merkle.offset"
// merkleSizeXattr is the extended attribute name specifying the size of data
// hashed by the corresponding Merkle tree. For a file, it's the size of the
-// whole file. For a directory, it's the size of all its children's root hashes.
+// whole file. For a directory, it's the size of all its children's hashes.
const merkleSizeXattr = "user.merkle.size"
// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
@@ -68,11 +68,6 @@ const sizeOfStringInt32 = 10
// flag.
var noCrashOnVerificationFailure bool
-// verityMu synchronizes enabling verity files, protects files or directories
-// from being enabled by different threads simultaneously. It also ensures that
-// verity does not access files that are being enabled.
-var verityMu sync.RWMutex
-
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
@@ -106,6 +101,17 @@ type filesystem struct {
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
renameMu sync.RWMutex `state:"nosave"`
+
+	// verityMu synchronizes enabling verity files and protects files or
+	// directories from being enabled by different threads simultaneously.
+	// It also ensures that verity does not access files that are being
+	// enabled.
+	//
+	// Also, a directory's Merkle tree depends on the generated trees of
+	// its children, so they shouldn't be enabled at the same time. This
+	// lock covers the whole file system to ensure that no more than one
+	// file is enabled at a time.
+ verityMu sync.RWMutex
}
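Moving the lock into the filesystem struct scopes serialization to one verity instance instead of every mount in the sentry. A minimal sketch of the per-instance shape:

package main

import "sync"

// fs carries its own lock, so independent verity mounts can enable
// files concurrently; the old package-level lock serialized them all.
type fs struct {
	verityMu sync.RWMutex
}

func (f *fs) enable() {
	f.verityMu.Lock()
	defer f.verityMu.Unlock()
	// ... generate the Merkle tree under the instance lock ...
}

func main() {
	a, b := &fs{}, &fs{}
	done := make(chan struct{})
	go func() { a.enable(); close(done) }()
	b.enable() // does not contend with a's lock
	<-done
}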
// InternalFilesystemOptions may be passed as
@@ -142,6 +148,17 @@ func (FilesystemType) Name() string {
return Name
}
+// isEnabled checks whether the target is enabled with verity features. This
+// is always true if runtime enable is not allowed. In runtime enable
+// mode, it returns true if the target has been enabled with
+// ioctl(FS_IOC_ENABLE_VERITY).
+func isEnabled(d *dentry) bool {
+ return !d.fs.allowRuntimeEnable || len(d.hash) != 0
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
// noCrashOnVerificationFailure mode, it returns an error; otherwise it panics.
@@ -245,13 +262,18 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
- // TODO(b/162788573): Verify Metadata.
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
+ d.hash = make([]byte, len(iopts.RootHash))
- d.rootHash = make([]byte, len(iopts.RootHash))
- copy(d.rootHash, iopts.RootHash)
+ if !fs.allowRuntimeEnable {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return nil, nil, err
+ }
+ }
+
+ copy(d.hash, iopts.RootHash)
d.vfsd.Init(d)
fs.rootDentry = d
@@ -303,8 +325,8 @@ type dentry struct {
// in the underlying file system.
lowerMerkleVD vfs.VirtualDentry
- // rootHash is the rootHash for the current file or directory.
- rootHash []byte
+ // hash is the calculated hash for the current file or directory.
+ hash []byte
}
// newDentry creates a new dentry representing the given verity file. The
@@ -488,6 +510,11 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(fd.d) {
+ if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
@@ -498,11 +525,11 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// generateMerkle generates a Merkle tree file for fd. If fd points to a file
-// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root
-// hash of the generated Merkle tree and the data size is returned.
-// If fd points to a regular file, the data is the content of the file. If fd
-// points to a directory, the data is all root hahes of its children, written
-// to the Merkle tree file.
+// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
+// of the generated Merkle tree and the data size are returned. If fd points to
+// a regular file, the data is the content of the file. If fd points to a
+// directory, the data is all hashes of its children, written to the Merkle
+// tree file.
func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
fdReader := vfs.FileReadWriteSeeker{
FD: fd.lowerFD,
@@ -516,8 +543,10 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
FD: fd.merkleWriter,
Ctx: ctx,
}
- var rootHash []byte
- var dataSize uint64
+ params := &merkletree.GenerateParams{
+ TreeReader: &merkleReader,
+ TreeWriter: &merkleWriter,
+ }
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
case linux.S_IFREG:
@@ -528,75 +557,90 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
if err != nil {
return nil, 0, err
}
- dataSize = stat.Size
- rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */)
- if err != nil {
- return nil, 0, err
- }
+ params.File = &fdReader
+ params.Size = int64(stat.Size)
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = false
case linux.S_IFDIR:
- // For a directory, generate a Merkle tree based on the root
- // hashes of its children that has already been written to the
- // Merkle tree file.
+		// For a directory, generate a Merkle tree based on the hashes
+		// of its children that have already been written to the Merkle
+		// tree file.
merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{})
if err != nil {
return nil, 0, err
}
- dataSize = merkleStat.Size
- rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */)
+ params.Size = int64(merkleStat.Size)
+
+ stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{})
if err != nil {
return nil, 0, err
}
+
+ params.File = &merkleReader
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = true
default:
// TODO(b/167728857): Investigate whether and how we should
		// enable other types of files.
return nil, 0, syserror.EINVAL
}
- return rootHash, dataSize, nil
+ hash, err := merkletree.Generate(params)
+ return hash, uint64(params.Size), err
}
// enableVerity enables verity features on fd by generating a Merkle tree file
-// and stores its root hash in its parent directory's Merkle tree.
-func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+// and storing its hash in its parent directory's Merkle tree.
+func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) {
if !fd.d.fs.allowRuntimeEnable {
return 0, syserror.EPERM
}
- // Lock to prevent other threads performing enable or access the file
- // while it's being enabled.
- verityMu.Lock()
- defer verityMu.Unlock()
+ fd.d.fs.verityMu.Lock()
+ defer fd.d.fs.verityMu.Unlock()
- if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil {
+ // In allowRuntimeEnable mode, the underlying fd and read/write fd for
+ // the Merkle tree file should have all been initialized. For any file
+ // or directory other than the root, the parent Merkle tree file should
+ // have also been initialized.
+ if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
}
- rootHash, dataSize, err := fd.generateMerkle(ctx)
+ hash, dataSize, err := fd.generateMerkle(ctx)
if err != nil {
return 0, err
}
- stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- return 0, err
- }
+ if fd.parentMerkleWriter != nil {
+ stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
- // Write the root hash of fd to the parent directory's Merkle tree
- // file, as it should be part of the parent Merkle tree data.
- // parentMerkleWriter is open with O_APPEND, so it should write
- // directly to the end of the file.
- if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil {
- return 0, err
- }
+ // Write the hash of fd to the parent directory's Merkle tree
+ // file, as it should be part of the parent Merkle tree data.
+ // parentMerkleWriter is open with O_APPEND, so it should write
+ // directly to the end of the file.
+ if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(hash), vfs.WriteOptions{}); err != nil {
+ return 0, err
+ }
- // Record the offset of the root hash of fd in parent directory's
- // Merkle tree file.
- if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
- Name: merkleOffsetInParentXattr,
- Value: strconv.Itoa(int(stat.Size)),
- }); err != nil {
- return 0, err
+ // Record the offset of the hash of fd in parent directory's
+ // Merkle tree file.
+ if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
+ Name: merkleOffsetInParentXattr,
+ Value: strconv.Itoa(int(stat.Size)),
+ }); err != nil {
+ return 0, err
+ }
}
// Record the size of the data being hashed for fd.
@@ -606,22 +650,59 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
}); err != nil {
return 0, err
}
- fd.d.rootHash = append(fd.d.rootHash, rootHash...)
+ fd.d.hash = append(fd.d.hash, hash...)
return 0, nil
}
-func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+// measureVerity returns the hash of fd, saved in verityDigest.
+func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ var metadata linux.DigestMetadata
+
+ // If allowRuntimeEnable is true, an empty fd.d.hash indicates that
+ // verity is not enabled for the file. If allowRuntimeEnable is false,
+ // this is an integrity violation because all files should have verity
+ // enabled, in which case fd.d.hash should be set.
+ if len(fd.d.hash) == 0 {
+ if fd.d.fs.allowRuntimeEnable {
+ return 0, syserror.ENODATA
+ }
+ return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found")
+ }
+
+ // The first part of VerityDigest is the metadata.
+ if _, err := metadata.CopyIn(t, verityDigest); err != nil {
+ return 0, err
+ }
+ if metadata.DigestSize < uint16(len(fd.d.hash)) {
+ return 0, syserror.EOVERFLOW
+ }
+
+ // Populate the output digest size, since DigestSize is both input and
+ // output.
+ metadata.DigestSize = uint16(len(fd.d.hash))
+
+ // First copy the metadata.
+ if _, err := metadata.CopyOut(t, verityDigest); err != nil {
+ return 0, err
+ }
+
+	// Now copy the hash bytes to the memory following the metadata.
+ _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash)
+ return 0, err
+}
+
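For reference, the measure buffer mirrors Linux's struct fsverity_digest: a u16 algorithm, a u16 size, then the digest bytes. A sketch of the layout with a hypothetical 4-byte digest (real digests are e.g. 32-byte SHA-256):

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	digest := []byte{0xde, 0xad, 0xbe, 0xef} // hypothetical digest
	// Header: digest_algorithm (u16) then digest_size (u16), as in
	// Linux's struct fsverity_digest; the digest bytes follow.
	buf := make([]byte, 4+len(digest))
	binary.LittleEndian.PutUint16(buf[0:2], 1) // FS_VERITY_HASH_ALG_SHA256
	binary.LittleEndian.PutUint16(buf[2:4], uint16(len(digest)))
	copy(buf[4:], digest)
	fmt.Printf("% x\n", buf)
}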
+func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) {
f := int32(0)
- // All enabled files should store a root hash. This flag is not settable
- // via FS_IOC_SETFLAGS.
- if len(fd.d.rootHash) != 0 {
+ // All enabled files should store a hash. This flag is not settable via
+ // FS_IOC_SETFLAGS.
+ if len(fd.d.hash) != 0 {
f |= linux.FS_VERITY_FL
}
t := kernel.TaskFromContext(ctx)
- addr := args[2].Pointer()
- _, err := primitive.CopyInt32Out(t, addr, f)
+ _, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -629,11 +710,15 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar
func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FS_IOC_ENABLE_VERITY:
- return fd.enableVerity(ctx, uio, args)
+ return fd.enableVerity(ctx, uio)
+ case linux.FS_IOC_MEASURE_VERITY:
+ return fd.measureVerity(ctx, uio, args[2].Pointer())
case linux.FS_IOC_GETFLAGS:
- return fd.getFlags(ctx, uio, args)
+ return fd.verityFlags(ctx, uio, args[2].Pointer())
default:
- return fd.lowerFD.Ioctl(ctx, uio, args)
+ // TODO(b/169682228): Investigate which ioctl commands should
+ // be allowed.
+ return 0, syserror.ENOSYS
}
}
@@ -641,10 +726,12 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
// allowRuntimeEnable mode.
- if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 {
+ if !isEnabled(fd.d) {
return fd.lowerFD.PRead(ctx, dst, offset, opts)
}
+ fd.d.fs.verityMu.RLock()
+ defer fd.d.fs.verityMu.RUnlock()
// dataSize is the size of the whole file.
dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
@@ -678,9 +765,22 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Ctx: ctx,
}
- n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */)
+ n, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ ReadOffset: offset,
+ ReadSize: dst.NumBytes(),
+ Expected: fd.d.hash,
+ DataAndTreeInSameFile: false,
+ })
if err != nil {
- return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
new file mode 100644
index 000000000..e301d35f5
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -0,0 +1,491 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// rootMerkleFilename is the name of the root Merkle tree file.
+const rootMerkleFilename = "root.verity"
+
+// maxDataSize is the maximum data size written to the file for test.
+const maxDataSize = 100000
+
+// newVerityRoot creates a new verity mount, and returns the root. The
+// underlying file system is tmpfs. If the returned error is nil, the root
+// is released automatically via t.Cleanup once the test completes.
+func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+ rand.Seed(time.Now().UnixNano())
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: InternalFilesystemOptions{
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ AllowRuntimeEnable: true,
+ NoCrashOnVerificationFailure: true,
+ },
+ },
+ })
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ }
+ root := mntns.Root()
+ root.IncRef()
+ t.Helper()
+ t.Cleanup(func() {
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
+ })
+ return vfsObj, root, nil
+}
+
+// newFileFD creates a new file in the verity mount and returns an FD for it,
+// along with the size of the random data written to the file.
+func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+
+ // Create the file in the underlying file system.
+ lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Generate random data to be written to the file.
+ dataSize := rand.Intn(maxDataSize) + 1
+ data := make([]byte, dataSize)
+ rand.Read(data)
+
+	// Write directly to the underlying FD, since the verity FD is read-only.
+ n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ if n != int64(len(data)) {
+ return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data))
+ }
+
+ lowerFD.DecRef(ctx)
+
+ // Now open the verity file descriptor.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular | mode,
+ })
+ return fd, dataSize, err
+}
+
+// corruptRandomBit randomly flips a bit in the file represented by fd.
+func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
+ // Flip a random bit in the underlying file.
+ randomPos := int64(rand.Intn(size))
+ byteToModify := make([]byte, 1)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
+		return fmt.Errorf("fd.PRead: %v", err)
+ }
+ byteToModify[0] ^= 1
+ if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil {
+		return fmt.Errorf("fd.PWrite: %v", err)
+ }
+ return nil
+}
+
+// TestOpen ensures that when a file is created, the corresponding Merkle tree
+// file and the root Merkle tree file exist.
+func TestOpen(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
+}
+
+// TestReadUnmodifiedFileSucceeds ensures that reading from an untouched verity
+// file succeeds after enabling verity for it.
+func TestReadUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
+}
+
+// TestReopenUnmodifiedFileSucceeds ensures that reopening an untouched verity
+// file succeeds after enabling verity for it.
+func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+	// Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
+}
+
+// TestModifiedFileFails ensures that reading from a modified verity file fails.
+func TestModifiedFileFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded with modified file")
+ }
+}
+
+// TestModifiedMerkleFails ensures that reading from a verity file fails if the
+// corresponding Merkle tree file is modified.
+func TestModifiedMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
+}
+
+// TestModifiedParentMerkleFails ensures that opening a verity-enabled file in
+// a verity-enabled directory fails if the hashes related to the target file
+// in the parent's Merkle tree file are modified.
+func TestModifiedParentMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+	// Open a new parentLowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
+}
+
+// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
+// succeeds after enabling verity for it.
+func TestUnmodifiedStatSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+	// Enable verity on the file and confirm stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
+}
+
+// TestModifiedStatFails checks that stat fails for a file whose metadata has
+// been modified in the underlying file system.
+func TestModifiedStatFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
+}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 61c78569d..300b7ccce 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -7,11 +7,14 @@ go_library(
srcs = [
"cgroup.go",
"hostmm.go",
+ "membarrier.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
"//pkg/fd",
"//pkg/log",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/hostmm/membarrier.go b/pkg/sentry/hostmm/membarrier.go
new file mode 100644
index 000000000..4468d75f1
--- /dev/null
+++ b/pkg/sentry/hostmm/membarrier.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostmm
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ haveMembarrierGlobal = false
+ haveMembarrierPrivateExpedited = false
+)
+
+func init() {
+ supported, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_QUERY, 0 /* flags */, 0 /* unused */)
+ if e != 0 {
+ if e != syscall.ENOSYS {
+ log.Warningf("membarrier(MEMBARRIER_CMD_QUERY) failed: %s", e.Error())
+ }
+ return
+ }
+ // We don't use MEMBARRIER_CMD_GLOBAL_EXPEDITED because this sends IPIs to
+ // all CPUs running tasks that have previously invoked
+// MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, which presents a DoS risk.
+ // (MEMBARRIER_CMD_GLOBAL is synchronize_rcu(), i.e. it waits for an RCU
+ // grace period to elapse without bothering other CPUs.
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED sends IPIs only to CPUs running tasks
+ // sharing the caller's MM.)
+ if supported&linux.MEMBARRIER_CMD_GLOBAL != 0 {
+ haveMembarrierGlobal = true
+ }
+ if req := uintptr(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED); supported&req == req {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ log.Warningf("membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) failed: %s", e.Error())
+ } else {
+ haveMembarrierPrivateExpedited = true
+ }
+ }
+}
+
+// HaveGlobalMemoryBarrier returns true if GlobalMemoryBarrier is supported.
+func HaveGlobalMemoryBarrier() bool {
+ return haveMembarrierGlobal
+}
+
+// GlobalMemoryBarrier blocks until "all running threads [in the host OS] have
+// passed through a state where all memory accesses to user-space addresses
+// match program order between entry to and return from [GlobalMemoryBarrier]",
+// as for membarrier(2).
+//
+// Preconditions: HaveGlobalMemoryBarrier() == true.
+func GlobalMemoryBarrier() error {
+ if _, _, e := syscall.Syscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_GLOBAL, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
+
+// HaveProcessMemoryBarrier returns true if ProcessMemoryBarrier is supported.
+func HaveProcessMemoryBarrier() bool {
+ return haveMembarrierPrivateExpedited
+}
+
+// ProcessMemoryBarrier is equivalent to GlobalMemoryBarrier, but only
+// synchronizes with threads sharing a virtual address space (from the host OS'
+// perspective) with the calling thread.
+//
+// Preconditions: HaveProcessMemoryBarrier() == true.
+func ProcessMemoryBarrier() error {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
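
With support recorded once in init, a hypothetical caller (not part of this change) would gate on the Have* predicates and prefer the private barrier, which IPIs only CPUs running threads that share the caller's mm:

	func memoryBarrier() error {
		if hostmm.HaveProcessMemoryBarrier() {
			return hostmm.ProcessMemoryBarrier()
		}
		if hostmm.HaveGlobalMemoryBarrier() {
			return hostmm.GlobalMemoryBarrier()
		}
		return nil // No host support; callers must tolerate this.
	}
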
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 083071b5e..5de70aecb 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -204,7 +204,6 @@ go_library(
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/amutex",
- "//pkg/binary",
"//pkg/bits",
"//pkg/bpf",
"//pkg/context",
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index d3e76ca7b..060c056df 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -89,7 +89,7 @@ func (kcov *Kcov) TaskWork(t *Task) {
kcov.mu.Lock()
defer kcov.mu.Unlock()
- if kcov.mode != linux.KCOV_TRACE_PC {
+ if kcov.mode != linux.KCOV_MODE_TRACE_PC {
return
}
@@ -146,7 +146,7 @@ func (kcov *Kcov) InitTrace(size uint64) error {
}
// EnableTrace performs the KCOV_ENABLE_TRACE ioctl.
-func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
+func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error {
t := TaskFromContext(ctx)
if t == nil {
panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine")
@@ -160,9 +160,9 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
return syserror.EINVAL
}
- switch traceMode {
+ switch traceKind {
case linux.KCOV_TRACE_PC:
- kcov.mode = traceMode
+ kcov.mode = linux.KCOV_MODE_TRACE_PC
case linux.KCOV_TRACE_CMP:
// We do not support KCOV_MODE_TRACE_CMP.
return syserror.ENOTSUP
@@ -175,6 +175,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
}
kcov.owningTask = t
+ t.SetKcov(kcov)
t.RegisterWork(kcov)
// Clear existing coverage data; the task expects to read only coverage data
@@ -196,26 +197,35 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error {
if t != kcov.owningTask {
return syserror.EINVAL
}
- kcov.owningTask = nil
kcov.mode = linux.KCOV_MODE_INIT
- kcov.resetLocked()
+ kcov.owningTask = nil
+ kcov.mappable = nil
return nil
}
-// Reset is called when the owning task exits.
-func (kcov *Kcov) Reset() {
+// Clear resets the mode and clears the owning task and memory mapping for kcov.
+// It is called when the fd corresponding to kcov is closed. Note that the mode
+// needs to be set so that the next call to kcov.TaskWork() will exit early.
+func (kcov *Kcov) Clear() {
kcov.mu.Lock()
- kcov.resetLocked()
+ kcov.clearLocked()
kcov.mu.Unlock()
}
-// The kcov instance is reset when the owning task exits or when tracing is
-// disabled.
-func (kcov *Kcov) resetLocked() {
+func (kcov *Kcov) clearLocked() {
+ kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
- if kcov.mappable != nil {
- kcov.mappable = nil
- }
+ kcov.mappable = nil
+}
+
+// OnTaskExit is called when the owning task exits. It is similar to
+// kcov.Clear(), except the memory mapping is not cleared, so that the same
+// mapping can be used in the future if kcov is enabled again by another task.
+func (kcov *Kcov) OnTaskExit() {
+ kcov.mu.Lock()
+ kcov.mode = linux.KCOV_MODE_INIT
+ kcov.owningTask = nil
+ kcov.mu.Unlock()
}
// ConfigureMMap is called by the vfs.FileDescription for this kcov instance to
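
The substance of this hunk is the split between two teardown paths: both reset the mode and drop the owning task, but only Clear drops the coverage mapping. Illustratively, per the comments above:

	kcov.Clear()      // fd close: also drops the memory mapping.
	kcov.OnTaskExit() // task exit: keeps the mapping so a later
	                  // KCOV_ENABLE from another task can reuse it.
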
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index d6c21adb7..675506269 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -841,14 +841,16 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
if ctx.args.MountNamespaceVFS2 == nil {
return nil
}
- // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
- return ctx.args.MountNamespaceVFS2.Root()
+ root := ctx.args.MountNamespaceVFS2.Root()
+ root.IncRef()
+ return root
case vfs.CtxMountNamespace:
if ctx.k.globalInit == nil {
return nil
}
- // MountNamespaceVFS2 takes a reference for us.
- return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns.IncRef()
+ return mntns
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case inet.CtxStack:
@@ -904,14 +906,13 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
if VFS2Enabled {
mntnsVFS2 = args.MountNamespaceVFS2
if mntnsVFS2 == nil {
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
+ // Add a reference to the namespace, which is transferred to the new process.
mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
+ mntnsVFS2.IncRef()
}
// Get the root directory from the MountNamespace.
root := mntnsVFS2.Root()
- // The call to newFSContext below will take a reference on root, so we
- // don't need to hold this one.
+ root.IncRef()
defer root.DecRef(ctx)
// Grab the working directory.
@@ -1648,16 +1649,16 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
if ctx.k.globalInit == nil {
return vfs.VirtualDentry{}
}
- mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
- defer mntns.DecRef(ctx)
- // Root() takes a reference on the root dirent for us.
- return mntns.Root()
+ root := ctx.k.GlobalInit().Leader().MountNamespaceVFS2().Root()
+ root.IncRef()
+ return root
case vfs.CtxMountNamespace:
if ctx.k.globalInit == nil {
return nil
}
- // MountNamespaceVFS2() takes a reference for us.
- return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntns.IncRef()
+ return mntns
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case inet.CtxStack:
@@ -1738,3 +1739,20 @@ func (k *Kernel) ShmMount() *vfs.Mount {
func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+
+// Release releases resources owned by k.
+//
+// Precondition: This should only be called after the kernel is fully
+// initialized, e.g. after k.Start() has been called.
+func (k *Kernel) Release() {
+ ctx := k.SupervisorContext()
+ if VFS2Enabled {
+ k.hostMount.DecRef(ctx)
+ k.pipeMount.DecRef(ctx)
+ k.shmMount.DecRef(ctx)
+ k.socketMount.DecRef(ctx)
+ k.vfs.Release(ctx)
+ }
+ k.timekeeper.Destroy()
+ k.vdso.Release(ctx)
+}
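
The common thread in this file's hunks is a reference-counting convention change: Root() and MountNamespaceVFS2() now return borrowed references rather than taking one on the caller's behalf. Under the new convention, a call site that retains the result follows this sketch:

	root := mntns.Root() // borrowed; no reference is transferred
	root.IncRef()        // take ownership explicitly
	defer root.DecRef(ctx)
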
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index c38c5a40c..387edfa91 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -18,7 +18,6 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
@@ -27,25 +26,18 @@ import (
const maxSyscallFilterInstructions = 1 << 15
-// seccompData is equivalent to struct seccomp_data, which contains the data
-// passed to seccomp-bpf filters.
-type seccompData struct {
- // nr is the system call number.
- nr int32
-
- // arch is an AUDIT_ARCH_* value indicating the system call convention.
- arch uint32
-
- // instructionPointer is the value of the instruction pointer at the time
- // of the system call.
- instructionPointer uint64
-
- // args contains the first 6 system call arguments.
- args [6]uint64
-}
-
-func (d *seccompData) asBPFInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder}
+// dataAsBPFInput returns d serialized as input to a seccomp BPF program. The
+// returned buffer is only valid on the current task goroutine.
+//
+// Note: this is called for every syscall, which is a very hot path.
+func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
+ buf := t.CopyScratchBuffer(d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ // Go-marshal always uses the native byte order.
+ Order: usermem.ByteOrder,
+ }
}
func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
@@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u
}
func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
- data := seccompData{
- nr: sysno,
- arch: t.tc.st.AuditNumber,
- instructionPointer: uint64(ip),
+ data := linux.SeccompData{
+ Nr: sysno,
+ Arch: t.tc.st.AuditNumber,
+ InstructionPointer: uint64(ip),
}
// data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so
// we can't do any slicing tricks or even use copy/append here.
for i, arg := range args {
- if i >= len(data.args) {
+ if i >= len(data.Args) {
break
}
- data.args[i] = arg.Uint64()
+ data.Args[i] = arg.Uint64()
}
- input := data.asBPFInput()
+ input := dataAsBPFInput(t, &data)
ret := uint32(linux.SECCOMP_RET_ALLOW)
f := t.syscallFilters.Load()
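
The hand-rolled seccompData struct and pkg/binary give way to linux.SeccompData with go-marshal. Because this path runs on every filtered syscall, the serialized bytes land in the task's reusable scratch buffer instead of a fresh allocation; the distilled pattern from this hunk:

	buf := t.CopyScratchBuffer(data.SizeBytes()) // per-task reusable buffer
	data.MarshalUnsafe(buf)                      // native byte order
	input := bpf.InputBytes{Data: buf, Order: usermem.ByteOrder}
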
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 3eb78e91b..76d472292 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index b07e1c1bd..78f718cfe 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -17,7 +17,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
@@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a436610c9..e90a19cfb 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -735,7 +735,6 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
func (t *Task) IsChrooted() bool {
if VFS2Enabled {
realRoot := t.mountNamespaceVFS2.Root()
- defer realRoot.DecRef(t)
root := t.fsContext.RootDirectoryVFS2()
defer root.DecRef(t)
return root != realRoot
@@ -868,7 +867,6 @@ func (t *Task) MountNamespace() *fs.MountNamespace {
func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
t.mu.Lock()
defer t.mu.Unlock()
- t.mountNamespaceVFS2.IncRef()
return t.mountNamespaceVFS2
}
@@ -917,7 +915,7 @@ func (t *Task) SetKcov(k *Kcov) {
// ResetKcov clears the kcov instance associated with t.
func (t *Task) ResetKcov() {
if t.kcov != nil {
- t.kcov.Reset()
+ t.kcov.OnTaskExit()
t.kcov = nil
}
}
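
Both removals above follow the same borrowed-reference convention as the kernel.go changes: IsChrooted drops its DecRef because Root() no longer transfers a reference, and MountNamespaceVFS2 drops its IncRef for the same reason. A caller that retains the namespace must take its own reference:

	mntns := t.MountNamespaceVFS2() // borrowed
	mntns.IncRef()                  // only if retaining it
	defer mntns.DecRef(ctx)
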
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 5ae5906e8..fdadb52c0 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -265,6 +265,13 @@ func (ns *PIDNamespace) Tasks() []*Task {
return tasks
}
+// NumTasks returns the number of tasks in ns.
+func (ns *PIDNamespace) NumTasks() int {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return len(ns.tids)
+}
+
// ThreadGroups returns a snapshot of the thread groups in ns.
func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup {
return ns.ThreadGroupsAppend(nil)
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index e44a139b3..9bc452e67 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -17,7 +17,6 @@ package kernel
import (
"fmt"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -28,6 +27,8 @@ import (
//
// They are exposed to the VDSO via a parameter page managed by VDSOParamPage,
// which also includes a sequence counter.
+//
+// +marshal
type vdsoParams struct {
monotonicReady uint64
monotonicBaseCycles int64
@@ -68,6 +69,13 @@ type VDSOParamPage struct {
// checked in state_test_util tests, causing this field to change across
// save / restore.
seq uint64
+
+ // copyScratchBuffer is a temporary buffer used to marshal the params before
+ // copying it to the real parameter page. The parameter page is typically
+ // updated at a moderate frequency of ~O(seconds) throughout the lifetime of
+ // the sentry, so reusing this buffer is a good tradeoff between memory
+ // usage and the cost of allocation.
+ copyScratchBuffer []byte
}
// NewVDSOParamPage returns a VDSOParamPage.
@@ -79,7 +87,11 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
- return &VDSOParamPage{mfp: mfp, fr: fr}
+ return &VDSOParamPage{
+ mfp: mfp,
+ fr: fr,
+ copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()),
+ }
}
// access returns a mapping of the param page.
@@ -133,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error {
// Get the new params.
p := f()
- buf := binary.Marshal(nil, usermem.ByteOrder, p)
+ buf := v.copyScratchBuffer[:p.SizeBytes()]
+ p.MarshalUnsafe(buf)
// Skip the sequence counter.
if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil {
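
Combined with the +marshal annotation on vdsoParams, parameter-page updates now marshal into the preallocated scratch buffer rather than allocating through pkg/binary on each update; the distilled write path from this hunk:

	buf := v.copyScratchBuffer[:p.SizeBytes()] // sized in NewVDSOParamPage
	p.MarshalUnsafe(buf)
	// The first 8 bytes of the page hold the sequence counter; skip them.
	_, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf))
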
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 05a294fe6..241d87835 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -380,3 +380,9 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
return vdsoAddr, nil
}
+
+// Release drops references on mappings held by v.
+func (v *VDSO) Release(ctx context.Context) {
+ v.ParamPage.DecRef(ctx)
+ v.vdso.DecRef(ctx)
+}
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index a44fa2b95..7fd77925f 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -127,7 +127,7 @@ func (t Translation) FileRange() FileRange {
// Preconditions: Same as Mappable.Translate.
func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
// Verify that the inputs to Mappable.Translate were valid.
- if !required.WellFormed() || required.Length() <= 0 {
+ if !required.WellFormed() || required.Length() == 0 {
panic(fmt.Sprintf("invalid required range: %v", required))
}
if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
@@ -145,7 +145,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required)
}
for i, t := range ts {
- if !t.Source.WellFormed() || t.Source.Length() <= 0 {
+ if !t.Source.WellFormed() || t.Source.Length() == 0 {
return fmt.Errorf("Translation %+v has invalid Source", t)
}
if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 8c9f11cce..92cc87d84 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -235,6 +235,20 @@ type MemoryManager struct {
// vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
vdsoSigReturnAddr uint64
+
+ // membarrierPrivateEnabled is non-zero if EnableMembarrierPrivate has
+ // previously been called. Since, as of this writing,
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED is implemented as a global memory
+ // barrier, membarrierPrivateEnabled has no other effect.
+ //
+ // membarrierPrivateEnabled is accessed using atomic memory operations.
+ membarrierPrivateEnabled uint32
+
+ // membarrierRSeqEnabled is non-zero if EnableMembarrierRSeq has previously
+ // been called.
+ //
+ // membarrierRSeqEnabled is accessed using atomic memory operations.
+ membarrierRSeqEnabled uint32
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 30facebf7..7e5f7de64 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -36,7 +36,7 @@ import (
// * ar.Length() != 0.
func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -100,7 +100,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user
// (i.e. permission checks must have been performed against vmas).
func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -193,7 +193,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR
// getVecPMAsLocked; other clients should call one of those instead.
func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -223,7 +223,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
// Need a pma here.
optAR := vseg.Range().Intersect(pgap.Range())
if checkInvariants {
- if optAR.Length() <= 0 {
+ if optAR.Length() == 0 {
panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap))
}
}
@@ -560,7 +560,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat
// Invalidate implements memmap.MappingSpace.Invalidate.
func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -583,7 +583,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate
// * ar must be page-aligned.
func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -629,7 +629,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
// * ar must be page-aligned.
func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -715,10 +715,10 @@ func Unpin(prs []PinnedRange) {
// * oldAR and newAR must be page-aligned.
func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
if checkInvariants {
- if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() {
+ if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() {
panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
}
- if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() {
+ if !newAR.WellFormed() || newAR.Length() == 0 || !newAR.IsPageAligned() {
panic(fmt.Sprintf("invalid newAR: %v", newAR))
}
if oldAR.Length() > newAR.Length() {
@@ -778,7 +778,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
// into mm.pmas.
func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -831,7 +831,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe
// * pseg.Range().Contains(ar.Start).
func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -1050,7 +1050,7 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if !pseg.Ok() {
panic("terminal pma iterator")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().IsSupersetOf(ar) {
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index a2555ba1a..675efdc7c 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -17,6 +17,7 @@ package mm
import (
"fmt"
mrand "math/rand"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -1274,3 +1275,27 @@ func (mm *MemoryManager) VirtualDataSize() uint64 {
defer mm.mappingMu.RUnlock()
return mm.dataAS
}
+
+// EnableMembarrierPrivate causes future calls to IsMembarrierPrivateEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierPrivate() {
+ atomic.StoreUint32(&mm.membarrierPrivateEnabled, 1)
+}
+
+// IsMembarrierPrivateEnabled returns true if mm.EnableMembarrierPrivate() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierPrivateEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierPrivateEnabled) != 0
+}
+
+// EnableMembarrierRSeq causes future calls to IsMembarrierRSeqEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierRSeq() {
+ atomic.StoreUint32(&mm.membarrierRSeqEnabled, 1)
+}
+
+// IsMembarrierRSeqEnabled returns true if mm.EnableMembarrierRSeq() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierRSeqEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierRSeqEnabled) != 0
+}
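
These per-MM flags let the membarrier(2) emulation enforce Linux's registration semantics: expedited commands fail unless the matching MEMBARRIER_CMD_REGISTER_* was issued first. A hypothetical handler check using the new accessors (the syscall handler itself is not part of this diff):

	// MEMBARRIER_CMD_PRIVATE_EXPEDITED requires prior registration.
	if !t.MemoryManager().IsMembarrierPrivateEnabled() {
		return syserror.EPERM
	}
	return hostmm.ProcessMemoryBarrier()
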
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index f769d8294..b8df72813 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -266,7 +266,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
// * ar.Length() != 0.
func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -350,7 +350,7 @@ const guardBytes = 256 * usermem.PageSize
// * ar must be page-aligned.
func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -371,7 +371,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange)
// * ar must be page-aligned.
func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -511,7 +511,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Range().IsSupersetOf(ar) {
@@ -536,7 +536,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !mr.WellFormed() || mr.Length() <= 0 {
+ if !mr.WellFormed() || mr.Length() == 0 {
panic(fmt.Sprintf("invalid mr: %v", mr))
}
if !vseg.mappableRange().IsSupersetOf(mr) {
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 209b28053..db7d55ef2 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/context",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/hostmm",
"//pkg/sentry/memmap",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 40fe4228a..8ce411102 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -9,12 +9,12 @@ go_library(
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
- "bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
"bluepill_arm64.go",
"bluepill_arm64.s",
"bluepill_arm64_unsafe.go",
"bluepill_fault.go",
+ "bluepill_impl_amd64.s",
"bluepill_unsafe.go",
"context.go",
"filters_amd64.go",
@@ -56,6 +56,7 @@ go_library(
"//pkg/sentry/time",
"//pkg/sync",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -79,6 +80,15 @@ go_test(
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/sentry/time",
"//pkg/usermem",
],
)
+
+genrule(
+ name = "bluepill_impl_amd64",
+ srcs = ["bluepill_amd64.s"],
+ outs = ["bluepill_impl_amd64.s"],
+    cmd = "(echo -e '// +build amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 2bc34a435..025ea93b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -19,11 +19,6 @@
// This is guaranteed to be zero.
#define VCPU_CPU 0x0
-// CPU_SELF is the self reference in ring0's percpu.
-//
-// This is guaranteed to be zero.
-#define CPU_SELF 0x0
-
// Context offsets.
//
// Only limited use of the context is done in the assembly stub below, most is
@@ -44,7 +39,7 @@ begin:
LEAQ VCPU_CPU(AX), BX
BYTE CLI;
check_vcpu:
- MOVQ CPU_SELF(GS), CX
+ MOVQ ENTRY_CPU_SELF(GS), CX
CMPQ BX, CX
JE right_vCPU
wrong_vcpu:
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 03a98512e..0a54dd30d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
- return c.runData.readyForInterruptInjection != 0
+ if c.runData.readyForInterruptInjection == 0 {
+ return false
+ }
+
+ if c.runData.ifFlag == 0 {
+ // This is impossible if readyForInterruptInjection is 1.
+ throw("interrupts are disabled")
+ }
+
+ // Disable interrupts if we are in the kernel space.
+ //
+ // When the Sentry switches into the kernel mode, it disables
+ // interrupts. But when goruntime switches on a goroutine which has
+ // been saved in the host mode, it restores flags and this enables
+ // interrupts. See the comment of UserFlagsSet for more details.
+ uregs := userRegs{}
+ err := c.getUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to get user registers")
+ }
+
+ if ring0.IsKernelFlags(uregs.RFLAGS) {
+ uregs.RFLAGS &^= ring0.KernelFlagsClear
+ err = c.setUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to set user registers")
+ }
+ return false
+ }
+ return true
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 979be5d89..eb05950cd 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -62,6 +62,9 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
//
//go:nosplit
func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Increment our counter.
+ atomic.AddUint64(&c.guestExits, 1)
+
// Copy out registers.
bluepillArchExit(c, bluepillArchContext(context))
@@ -89,9 +92,6 @@ func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
c := bluepillArchEnter(bluepillArchContext(context))
- // Increment the number of switches.
- atomic.AddUint32(&c.switches, 1)
-
// Mark this as guest mode.
switch atomic.SwapUint32(&c.state, vCPUGuest|vCPUUser) {
case vCPUUser: // Expected case.
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index 6e6b76416..17268d127 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -15,6 +15,8 @@
package kvm
import (
+ "sync/atomic"
+
pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -75,6 +77,9 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
// Clear the address space.
cpu.active.set(nil)
+ // Increment the number of user exits.
+ atomic.AddUint64(&cpu.userExits, 1)
+
// Release resources.
c.machine.Put(cpu)
diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go
index 7d949f1dd..d3d216aa5 100644
--- a/pkg/sentry/platform/kvm/filters_amd64.go
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
@@ -17,14 +17,23 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: {},
- syscall.SYS_IOCTL: {},
+ syscall.SYS_ARCH_PRCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
index 9245d07c2..21abc2a3d 100644
--- a/pkg/sentry/platform/kvm/filters_arm64.go
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -17,13 +17,22 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_IOCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
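
Both platform filters admit membarrier(2) in only the single form the sentry issues at runtime: cmd == MEMBARRIER_CMD_PRIVATE_EXPEDITED with flags == 0. In isolation the rule reads:

	unix.SYS_MEMBARRIER: []seccomp.Rule{
		{
			seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED), // cmd
			seccomp.EqualTo(0),                                      // flags
		},
	},

The one-time QUERY and REGISTER calls in hostmm's init presumably run before these filters are installed, so they need no rule of their own.
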
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ae813e24e..dd45ad10b 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -63,6 +63,9 @@ type runData struct {
type KVM struct {
platform.NoCPUPreemptionDetection
+ // KVM never changes mm_structs.
+ platform.UseHostProcessMemoryBarrier
+
// machine is the backing VM.
machine *machine
}
@@ -156,15 +159,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
pageTables := pagetables.New(newAllocator())
- applyPhysicalRegions(func(pr physicalRegion) bool {
- // Map the kernel in the upper half.
- pageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
- return true // Keep iterating.
- })
+ k.machine.mapUpperHalf(pageTables)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_arm64_test.go b/pkg/sentry/platform/kvm/kvm_arm64_test.go
index d34553a96..0e3d84d95 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64_test.go
@@ -23,10 +23,9 @@ import (
)
func TestKernelTLS(t *testing.T) {
- bluepillTest(t, func(c *vCPU) {
- if !testutil.TLSWorks() {
- t.Errorf("tls does not work, and it should!")
- }
- })
+ bluepillTest(t, func(c *vCPU) {
+ if !testutil.TLSWorks() {
+ t.Errorf("tls does not work, and it should!")
+ }
+ })
}
-
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 5c4b18899..6abaa21c4 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -26,12 +26,16 @@ const (
_KVM_RUN = 0xae80
_KVM_NMI = 0xae9a
_KVM_CHECK_EXTENSION = 0xae03
+ _KVM_GET_TSC_KHZ = 0xaea3
+ _KVM_SET_TSC_KHZ = 0xaea2
_KVM_INTERRUPT = 0x4004ae86
_KVM_SET_MSRS = 0x4008ae89
_KVM_SET_USER_MEMORY_REGION = 0x4020ae46
_KVM_SET_REGS = 0x4090ae82
_KVM_SET_SREGS = 0x4138ae84
+ _KVM_GET_MSRS = 0xc008ae88
_KVM_GET_REGS = 0x8090ae81
+ _KVM_GET_SREGS = 0x8138ae83
_KVM_GET_SUPPORTED_CPUID = 0xc008ae05
_KVM_SET_CPUID2 = 0x4008ae90
_KVM_SET_SIGNAL_MASK = 0x4004ae8b
@@ -79,11 +83,14 @@ const (
)
// KVM hypercall list.
+//
// Canonical list of hypercalls supported.
const (
// On amd64, it uses 'HLT' to leave the guest.
+ //
	// Unlike amd64, arm64 can only use mmio_exit/psci to leave the guest.
- // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ //
+ // _KVM_HYPERCALL_VMEXIT is only used on arm64 for now.
_KVM_HYPERCALL_VMEXIT int = iota
_KVM_HYPERCALL_MAX
)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 9a7be3655..84df0f878 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -101,13 +101,20 @@ const (
// Arm64: Memory Attribute Indirection Register EL1.
const (
- _MT_DEVICE_nGnRnE = 0
- _MT_DEVICE_nGnRE = 1
- _MT_DEVICE_GRE = 2
- _MT_NORMAL_NC = 3
- _MT_NORMAL = 4
- _MT_NORMAL_WT = 5
- _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_ATTR_DEVICE_nGnRnE = 0x00
+ _MT_ATTR_DEVICE_nGnRE = 0x04
+ _MT_ATTR_DEVICE_GRE = 0x0c
+ _MT_ATTR_NORMAL_NC = 0x44
+ _MT_ATTR_NORMAL_WT = 0xbb
+ _MT_ATTR_NORMAL = 0xff
+ _MT_ATTR_MASK = 0xff
+ _MT_EL1_INIT = (_MT_ATTR_DEVICE_nGnRnE << (_MT_DEVICE_nGnRnE * 8)) | (_MT_ATTR_DEVICE_nGnRE << (_MT_DEVICE_nGnRE * 8)) | (_MT_ATTR_DEVICE_GRE << (_MT_DEVICE_GRE * 8)) | (_MT_ATTR_NORMAL_NC << (_MT_NORMAL_NC * 8)) | (_MT_ATTR_NORMAL << (_MT_NORMAL * 8)) | (_MT_ATTR_NORMAL_WT << (_MT_NORMAL_WT * 8))
)
const (
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 45b3180f1..e58acc071 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -411,9 +412,9 @@ func TestWrongVCPU(t *testing.T) {
// Basic test, one then the other.
bluepill(c1)
bluepill(c2)
- if c2.switches == 0 {
+ if c2.guestExits == 0 {
// Don't allow the test to proceed if this fails.
- t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ t.Fatalf("wrong vCPU#2 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
// Alternate vCPUs; we expect to need to trigger the
@@ -422,11 +423,11 @@ func TestWrongVCPU(t *testing.T) {
bluepill(c1)
bluepill(c2)
}
- if count := c1.switches; count < 90 {
- t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ if count := c1.guestExits; count < 90 {
+ t.Errorf("wrong vCPU#1 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
- if count := c2.switches; count < 90 {
- t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2)
+ if count := c2.guestExits; count < 90 {
+ t.Errorf("wrong vCPU#2 exits: vCPU1=%+v,vCPU2=%+v", c1, c2)
}
return false
})
@@ -442,6 +443,22 @@ func TestWrongVCPU(t *testing.T) {
})
}
+func TestRdtsc(t *testing.T) {
+ var i int // Iteration count.
+ kvmTest(t, nil, func(c *vCPU) bool {
+ start := ktime.Rdtsc()
+ bluepill(c)
+ guest := ktime.Rdtsc()
+ redpill()
+ end := ktime.Rdtsc()
+ if start > guest || guest > end {
+ t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end)
+ }
+ i++
+ return i < 100
+ })
+}
+
func BenchmarkApplicationSyscall(b *testing.B) {
var (
i int // Iteration includes machine.Get() / machine.Put().
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 372a4cbd7..61ed24d01 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -103,8 +103,11 @@ type vCPU struct {
// tid is the last set tid.
tid uint64
- // switches is a count of world switches (informational only).
- switches uint32
+ // userExits is the count of user exits.
+ userExits uint64
+
+ // guestExits is the count of guest to host world switches.
+ guestExits uint64
// faults is a count of world faults (informational only).
faults uint32
@@ -127,6 +130,7 @@ type vCPU struct {
// vCPUArchState is the architecture-specific state.
vCPUArchState
+ // dieState holds state related to vCPU death.
dieState dieState
}
@@ -155,7 +159,7 @@ func (m *machine) newVCPU() *vCPU {
fd: int(fd),
machine: m,
}
- c.CPU.Init(&m.kernel, c)
+ c.CPU.Init(&m.kernel, c.id, c)
m.vCPUsByID[c.id] = c
// Ensure the signal mask is correct.
@@ -183,9 +187,6 @@ func newMachine(vm int) (*machine, error) {
// Create the machine.
m := &machine{fd: vm}
m.available.L = &m.mu
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- })
// Pull the maximum vCPUs.
maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
@@ -197,6 +198,9 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ }, m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -219,15 +223,9 @@ func newMachine(vm int) (*machine, error) {
pagetables.MapOpts{AccessType: usermem.AnyAccess},
pr.physical)
- // And keep everything in the upper half.
- m.kernel.PageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
-
return true // Keep iterating.
})
+ m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -365,6 +363,11 @@ func (m *machine) Destroy() {
// Get gets an available vCPU.
//
// This will return with the OS thread locked.
+//
+// It is guaranteed that if any OS thread TID is in guest, m.vCPUsByTID[TID]
+// points to the vCPU on which the OS thread TID is running. So if Get()
+// returns with the current context in guest, it must return the vCPU that
+// the calling thread is already running on.
func (m *machine) Get() *vCPU {
m.mu.RLock()
runtime.LockOSThread()
@@ -469,6 +472,19 @@ func (m *machine) newDirtySet() *dirtySet {
}
}
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
+ }
+}
+
// lock marks the vCPU as in user mode.
//
// This should only be called directly when known to be safe, i.e. when
@@ -528,6 +544,8 @@ var pid = syscall.Getpid()
//
// This effectively unwinds the state machine.
func (c *vCPU) bounce(forceGuestExit bool) {
+ origGuestExits := atomic.LoadUint64(&c.guestExits)
+ origUserExits := atomic.LoadUint64(&c.userExits)
for {
switch state := atomic.LoadUint32(&c.state); state {
case vCPUReady, vCPUWaiter:
@@ -583,6 +601,14 @@ func (c *vCPU) bounce(forceGuestExit bool) {
// Should not happen: the above is exhaustive.
panic("invalid state")
}
+
+ // Check whether we missed the state transition; if the exit
+ // counters have advanced, it is safe to return now.
+ newGuestExits := atomic.LoadUint64(&c.guestExits)
+ newUserExits := atomic.LoadUint64(&c.userExits)
+ if newUserExits != origUserExits && (!forceGuestExit || newGuestExits != origGuestExits) {
+ return
+ }
}
}
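The per-exit counters give bounce() a cheap way to notice that the vCPU made forward progress while it was looping: if userExits (and, when forceGuestExit is set, guestExits) advanced since entry, the awaited transition already happened. A minimal, self-contained sketch of the same counter-based pattern (illustrative only, not gVisor code):

package main

import (
	"sync/atomic"
	"time"
)

type worker struct{ exits uint64 }

// exit records one pass through the state of interest.
func (w *worker) exit() { atomic.AddUint64(&w.exits, 1) }

// waitForExit returns once the worker has exited at least once since the
// call began, mirroring the origUserExits/newUserExits comparison above.
func (w *worker) waitForExit() {
	orig := atomic.LoadUint64(&w.exits)
	for atomic.LoadUint64(&w.exits) == orig {
		time.Sleep(time.Millisecond) // The real code futex-waits instead.
	}
}

func main() {
	w := &worker{}
	go func() {
		time.Sleep(10 * time.Millisecond)
		w.exit()
	}()
	w.waitForExit()
}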
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index acc823ba6..c67127d95 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -18,14 +18,17 @@ package kvm
import (
"fmt"
+ "math/big"
"reflect"
"runtime/debug"
"syscall"
+ "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -84,19 +87,6 @@ const (
poolPCIDs = 8
)
-// dropPageTables drops cached page table entries.
-func (m *machine) dropPageTables(pt *pagetables.PageTables) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- // Clear from all PCIDs.
- for _, c := range m.vCPUsByID {
- if c != nil && c.PCIDs != nil {
- c.PCIDs.Drop(pt)
- }
- }
-}
-
// initArchState initializes architecture-specific state.
func (c *vCPU) initArchState() error {
var (
@@ -144,6 +134,7 @@ func (c *vCPU) initArchState() error {
// Set the entrypoint for the kernel.
kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RSP = c.StackTop()
kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
// Set the system registers.
@@ -152,8 +143,8 @@ func (c *vCPU) initArchState() error {
}
// Set the user registers.
- if err := c.setUserRegisters(&kernelUserRegs); err != nil {
- return err
+ if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
}
// Allocate some floating point state save area for the local vCPU.
@@ -166,6 +157,133 @@ func (c *vCPU) initArchState() error {
return c.setSystemTime()
}
+// bitsForScaling returns the bits available for storing the fraction component
+// of the TSC scaling ratio. This allows us to replicate the (bad) math done by
+// the kernel below in scaledTSC, and ensure we can compute an exact zero
+// offset in setSystemTime.
+//
+// These constants correspond to kvm_tsc_scaling_ratio_frac_bits.
+var bitsForScaling = func() int64 {
+ fs := cpuid.HostFeatureSet()
+ if fs.Intel() {
+ return 48 // See vmx.c (kvm sources).
+ } else if fs.AMD() {
+ return 32 // See svm.c (kvm sources).
+ } else {
+ return 63 // Unknown: theoretical maximum.
+ }
+}()
+
+// scaledTSC returns the host TSC scaled by the given frequency.
+//
+// This assumes a current frequency of 1. We require only the unitless ratio of
+// rawFreq to some current frequency. See setSystemTime for context.
+//
+// The kernel math guarantees that all bits of the multiplication and division
+// will be correctly preserved and applied. However, it is not possible to
+// actually store the ratio correctly. So we need to use the same scheme in
+// order to calculate the scaled frequency and get the same result.
+//
+// We can assume that the current frequency is (1), so we are calculating a
+// strict inverse of this value. This simplifies this function considerably.
+//
+// Roughly, the returned value "scaledTSC" will have:
+// scaledTSC/hostTSC == 1/rawFreq
+//
+//go:nosplit
+func scaledTSC(rawFreq uintptr) int64 {
+ scale := int64(1 << bitsForScaling)
+ ratio := big.NewInt(scale / int64(rawFreq))
+ ratio.Mul(ratio, big.NewInt(int64(ktime.Rdtsc())))
+ ratio.Div(ratio, big.NewInt(scale))
+ return ratio.Int64()
+}
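The truncation in the ratio above is exactly what must be reproduced: KVM cannot store 1/rawFreq precisely, so setSystemTime has to use the kernel's own rounding. A self-contained sketch with made-up numbers (a hypothetical 3 GHz host TSC; not gVisor code):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	const bits = 48           // Fraction bits on Intel (see bitsForScaling).
	rawFreq := int64(3000000) // Assumed host TSC frequency in KHz (3 GHz).
	hostTSC := int64(1) << 40 // An arbitrary host TSC sample.

	scale := int64(1) << bits
	ratio := big.NewInt(scale / rawFreq) // Truncated, exactly as scaledTSC does.
	ratio.Mul(ratio, big.NewInt(hostTSC))
	ratio.Div(ratio, big.NewInt(scale))

	// The scaled counter ticks at ~1 KHz: scaled/hostTSC ~= 1/rawFreq.
	fmt.Println(ratio.Int64(), hostTSC/rawFreq) // 366503 366503
}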
+
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ // First, scale down the clock frequency to the lowest value allowed by
+ // the API itself. How low we can go depends on the underlying
+ // hardware, but it is typically ~1/2^48 for Intel, ~1/2^32 for AMD.
+ // Even the lower bound here will take a 4 GHz frequency down to ~1 Hz,
+ // meaning that everything should be able to handle a KHz setting of 1
+ // with bits to spare.
+ //
+ // Note that reducing the clock does not typically require special
+ // capabilities as it is emulated in KVM. We don't actually use this
+ // capability, but it means that this method should be robust to
+ // different hardware configurations.
+ rawFreq, err := c.getTSCFreq()
+ if err != nil {
+ return c.setSystemTimeLegacy()
+ }
+ if err := c.setTSCFreq(1); err != nil {
+ return c.setSystemTimeLegacy()
+ }
+
+ // Always restore the original frequency.
+ defer func() {
+ if err := c.setTSCFreq(rawFreq); err != nil {
+ panic(err.Error())
+ }
+ }()
+
+ // Attempt to set the system time in this compressed world. The
+ // calculation for offset normally looks like:
+ //
+ // offset = target_tsc - kvm_scale_tsc(vcpu, rdtsc());
+ //
+ // So as long as the kvm_scale_tsc component is constant before and
+ // after the call to set the TSC value (and it is passed as the
+ // target_tsc), we will compute an offset value of zero.
+ //
+ // This is effectively cheating to make our "setSystemTime" call so
+ // unbelievably, incredibly fast that we do it "instantly" and all the
+ // calculations result in an offset of zero.
+ lastTSC := scaledTSC(rawFreq)
+ for {
+ if err := c.setTSC(uint64(lastTSC)); err != nil {
+ return err
+ }
+ nextTSC := scaledTSC(rawFreq)
+ if lastTSC == nextTSC {
+ return nil
+ }
+ lastTSC = nextTSC // Try again.
+ }
+}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~12.5% (minimum/8) of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
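The shift expression in the loop is an integer-only way of allowing minimum plus one eighth, i.e. a 12.5% band above the best observed call time; a trivial sketch, assuming a minimum of 1000 cycles:

package main

import "fmt"

func main() {
	minimum := uint64(1000)
	upperThreshold := ((minimum << 3) + minimum) >> 3 // minimum * 9 / 8.
	fmt.Println(upperThreshold)                       // 1125, i.e. +12.5%.
}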
+
// nonCanonical generates a canonical address return.
//
//go:nosplit
@@ -345,3 +463,41 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
+
+var execRegions = func() (regions []region) {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
+ return
+ }
+ if vr.accessType.Execute {
+ regions = append(regions, vr.region)
+ }
+ })
+ return
+}()
+
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ for _, r := range execRegions {
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossilbe translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
+ }
+ for start, end := range m.kernel.EntryRegions() {
+ regionLen := end - start
+ physical, length, ok := translateToPhysical(start)
+ if !ok || length < regionLen {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|start),
+ regionLen,
+ pagetables.MapOpts{AccessType: usermem.ReadWrite},
+ physical)
+ }
+}
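Both loops rely on the same aliasing trick as jumpToKernel: the upper-half alias of a lower-half address is formed by OR-ing in KernelStartAddress, so the same physical frames are reachable from either half. A sketch assuming a 48-bit virtual address space, where the kernel half starts at 0xffff800000000000 (not gVisor code):

package main

import "fmt"

func main() {
	const kernelStartAddress = uint64(0xffff800000000000) // Assumed 48-bit split.
	lower := uint64(0x00007f1234567000)                   // A lower-half address.
	upper := lower | kernelStartAddress                   // Its upper-half alias.
	// Both addresses map to the same physical frame, so execution can
	// continue across a CR3 switch while the upper half stays mapped.
	fmt.Printf("%#x -> %#x\n", lower, upper)
}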
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 290f035dd..b430f92c6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -23,7 +23,6 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/time"
)
// loadSegments copies the current segments.
@@ -61,91 +60,63 @@ func (c *vCPU) setCPUID() error {
return nil
}
-// setSystemTime sets the TSC for the vCPU.
+// getTSCFreq gets the TSC frequency.
//
-// This has to make the call many times in order to minimize the intrinsic
-// error in the offset. Unfortunately KVM does not expose a relative offset via
-// the API, so this is an approximation. We do this via an iterative algorithm.
-// This has the advantage that it can generally deal with highly variable
-// system call times and should converge on the correct offset.
-func (c *vCPU) setSystemTime() error {
- const (
- _MSR_IA32_TSC = 0x00000010
- calibrateTries = 10
- )
- registers := modelControlRegisters{
- nmsrs: 1,
- }
- registers.entries[0] = modelControlRegister{
- index: _MSR_IA32_TSC,
+// It returns an error if the _KVM_GET_TSC_KHZ ioctl fails.
+func (c *vCPU) getTSCFreq() (uintptr, error) {
+ rawFreq, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_TSC_KHZ,
+ 0 /* ignored */)
+ if errno != 0 {
+ return 0, errno
}
- target := uint64(^uint32(0))
- for done := 0; done < calibrateTries; {
- start := uint64(time.Rdtsc())
- registers.entries[0].data = start + target
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_MSRS,
- uintptr(unsafe.Pointer(&registers))); errno != 0 {
- return fmt.Errorf("error setting system time: %v", errno)
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result. So we only count attempts
- // within +/- 6.25% of our minimum as an attempt.
- end := uint64(time.Rdtsc())
- if end < start {
- continue // Totally bogus.
- }
- half := (end - start) / 2
- if half < target {
- target = half
- }
- if (half - target) < target/8 {
- done++
- }
+ return rawFreq, nil
+}
+
+// setTSCFreq sets the TSC frequency.
+func (c *vCPU) setTSCFreq(freq uintptr) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_TSC_KHZ,
+ freq /* khz */); errno != 0 {
+ return fmt.Errorf("error setting TSC frequency: %v", errno)
}
return nil
}
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
+// setTSC sets the TSC value.
+func (c *vCPU) setTSC(value uint64) error {
+ const _MSR_IA32_TSC = 0x00000010
+ registers := modelControlRegisters{
+ nmsrs: 1,
}
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
+ registers.entries[0].index = _MSR_IA32_TSC
+ registers.entries[0].data = value
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
+ _KVM_SET_MSRS,
+ uintptr(unsafe.Pointer(&registers))); errno != 0 {
+ return fmt.Errorf("error setting tsc: %v", errno)
}
return nil
}
// setUserRegisters sets user registers in the vCPU.
-func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+//
+//go:nosplit
+func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno {
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_REGS,
uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return fmt.Errorf("error setting user registers: %v", errno)
+ return errno
}
- return nil
+ return 0
}
// getUserRegisters reloads user registers in the vCPU.
@@ -175,3 +146,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
}
return nil
}
+
+// getSystemRegisters gets the system registers.
+//
+//go:nosplit
+func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 9db171af9..54837f20c 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -19,6 +19,7 @@ package kvm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -48,6 +49,18 @@ const (
poolPCIDs = 8
)
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+}
+
// Get all read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
var rdonlyRegions []region
@@ -100,19 +113,6 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return phyRegions
}
-// dropPageTables drops cached page table entries.
-func (m *machine) dropPageTables(pt *pagetables.PageTables) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- // Clear from all PCIDs.
- for _, c := range m.vCPUsByID {
- if c.PCIDs != nil {
- c.PCIDs.Drop(pt)
- }
- }
-}
-
// nonCanonical generates a canonical address return.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 537419657..a163f956d 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -191,42 +191,6 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
return nil
}
-// setCPUID sets the CPUID to be used by the guest.
-func (c *vCPU) setCPUID() error {
- return nil
-}
-
-// setSystemTime sets the TSC for the vCPU.
-func (c *vCPU) setSystemTime() error {
- return nil
-}
-
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
- }
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
- }
-
- return nil
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
// Check for canonical addresses.
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 607c82156..1d6ca245a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -143,3 +143,29 @@ func (c *vCPU) waitUntilNot(state uint32) {
panic("futex wait error")
}
}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 530e779b0..dcfe839a7 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/hostmm"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,6 +53,10 @@ type Platform interface {
// can reliably return ErrContextCPUPreempted.
DetectsCPUPreemption() bool
+ // HaveGlobalMemoryBarrier returns true if the GlobalMemoryBarrier method
+ // is supported.
+ HaveGlobalMemoryBarrier() bool
+
// MapUnit returns the alignment used for optional mappings into this
// platform's AddressSpaces. Higher values indicate lower per-page costs
// for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
@@ -97,6 +102,15 @@ type Platform interface {
// called.
PreemptAllCPUs() error
+ // GlobalMemoryBarrier blocks until all threads running application code
+ // (via Context.Switch) and all task goroutines "have passed through a
+ // state where all memory accesses to user-space addresses match program
+ // order between entry to and return from [GlobalMemoryBarrier]", as for
+ // membarrier(2).
+ //
+ // Preconditions: HaveGlobalMemoryBarrier() == true.
+ GlobalMemoryBarrier() error
+
// SyscallFilters returns syscalls made exclusively by this platform.
SyscallFilters() seccomp.SyscallRules
}
@@ -115,6 +129,43 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
panic("This platform does not support CPU preemption detection")
}
+// UseHostGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking equivalent functionality on the
+// host.
+type UseHostGlobalMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ return hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) GlobalMemoryBarrier() error {
+ return hostmm.GlobalMemoryBarrier()
+}
+
+// UseHostProcessMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking a process-local memory barrier.
+// This is faster than UseHostGlobalMemoryBarrier, but is only appropriate for
+// platforms for which application code executes while using the sentry's
+// mm_struct.
+type UseHostProcessMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ // Fall back to a global memory barrier if a process-local one isn't
+ // available.
+ return hostmm.HaveProcessMemoryBarrier() || hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) GlobalMemoryBarrier() error {
+ if hostmm.HaveProcessMemoryBarrier() {
+ return hostmm.ProcessMemoryBarrier()
+ }
+ return hostmm.GlobalMemoryBarrier()
+}
+
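As a usage sketch (a hypothetical platform, not part of this change): a platform whose application code shares the sentry's mm_struct would embed the process-local helper, falling back to the global barrier automatically when membarrier(2) lacks the process-local commands. The ptrace platform below embeds the global variant instead, since its application code runs in separate address spaces.

package myplatform // Hypothetical.

import "gvisor.dev/gvisor/pkg/sentry/platform"

// fastPlatform runs application code in the sentry's address space, so the
// cheaper process-local membarrier is correct; HaveGlobalMemoryBarrier and
// GlobalMemoryBarrier come from the embedded helper.
type fastPlatform struct {
	platform.MMapMinAddr
	platform.NoCPUPreemptionDetection
	platform.UseHostProcessMemoryBarrier
	// Remaining Platform methods elided.
}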
// MemoryManager represents an abstraction above the platform address space
// which manages memory mappings and their contents.
type MemoryManager interface {
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index b52d0fbd8..f56aa3b79 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -192,6 +192,7 @@ func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {}
type PTrace struct {
platform.MMapMinAddr
platform.NoCPUPreemptionDetection
+ platform.UseHostGlobalMemoryBarrier
}
// New returns a new ptrace-based implementation of the platform interface.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 9c6c2cf5c..00899273e 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -76,15 +76,41 @@ type KernelOpts struct {
type KernelArchState struct {
KernelOpts
+ // cpuEntries is an array of kernelEntry, one per CPU.
+ cpuEntries []kernelEntry
+
// globalIDT is our set of interrupt gates.
- globalIDT idt64
+ globalIDT *idt64
}
-// CPUArchState contains CPU-specific arch state.
-type CPUArchState struct {
+// kernelEntry contains the minimal CPU-specific arch state that can be
+// mapped into the upper half of the address space. A malicious application
+// might be able to steal information from it via CPU bugs.
+type kernelEntry struct {
// stack is the stack used for interrupts on this CPU.
stack [256]byte
+ // scratch0 is scratch space for temporary usage.
+ scratch0 uint64
+
+ // stackTop is the top of the stack.
+ stackTop uint64
+
+ // cpuSelf is a back reference to the CPU.
+ cpuSelf *CPU
+
+ // kernelCR3 is the CR3 used for the sentry kernel.
+ kernelCR3 uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
// errorCode is the error code from the last exception.
errorCode uintptr
@@ -97,11 +123,7 @@ type CPUArchState struct {
// exception.
errorType uintptr
- // gdt is the CPU's descriptor table.
- gdt descriptorTable
-
- // tss is the CPU's task state.
- tss TaskState64
+ *kernelEntry
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index 7fa43c2f5..d87b1fd00 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -36,12 +36,15 @@ func sysenter()
// This must be called prior to sysret/iret.
func swapgs()
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
// sysret returns to userspace from a system call.
//
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *arch.Registers) Vector
+func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *arch.Registers) Vector
+func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s
index 02df38331..f59747df3 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.s
+++ b/pkg/sentry/platform/ring0/entry_amd64.s
@@ -63,6 +63,15 @@
MOVQ offset+PTRACE_RSI(reg), SI; \
MOVQ offset+PTRACE_RDI(reg), DI;
+// WRITE_CR3() writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+#define WRITE_CR3() \
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+
// SWAP_GS swaps the kernel GS (CPU).
#define SWAP_GS() \
BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
@@ -75,15 +84,9 @@
#define SYSRET64() \
BYTE $0x48; BYTE $0x0f; BYTE $0x07;
-// LOAD_KERNEL_ADDRESS loads a kernel address.
-#define LOAD_KERNEL_ADDRESS(from, to) \
- MOVQ from, to; \
- ORQ ·KernelStartAddress(SB), to;
-
// LOAD_KERNEL_STACK loads the kernel stack.
-#define LOAD_KERNEL_STACK(from) \
- LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
- LEAQ CPU_STACK_TOP(SP), SP;
+#define LOAD_KERNEL_STACK(entry) \
+ MOVQ ENTRY_STACK_TOP(entry), SP;
// See kernel.go.
TEXT ·Halt(SB),NOSPLIT,$0
@@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0
SWAP_GS()
RET
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return address to its kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ ORQ ·KernelStartAddress(SB), AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
// See entry_amd64.go.
TEXT ·sysret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception() from
+ // the application (guest ring 3) will switch to this stack, set
+ // the return value (vector: 32(SP)) and then do RET, which will
+ // also automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+ // Save SP, AX and userCR3 on the kernel stack.
+ MOVQ CPU_ENTRY(BX), BX
+ LOAD_KERNEL_STACK(BX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_RAX(AX)
+ PUSHQ CX
+
// Restore user register state.
REGISTERS_LOAD(AX, 0)
MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
- MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+
+ // Restore userCR3, AX and SP.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
+ POPQ SP // Restore SP.
SYSRET64()
// See entry_amd64.go.
TEXT ·iret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception() from
+ // the application (guest ring 3) will switch to this stack, set
+ // the return value (vector: 32(SP)) and then do RET, which will
+ // also automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
// Build an IRET frame & restore state.
+ MOVQ CPU_ENTRY(BX), BX
LOAD_KERNEL_STACK(BX)
- MOVQ PTRACE_SS(AX), BX; PUSHQ BX
- MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
- MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
- MOVQ PTRACE_CS(AX), DI; PUSHQ DI
- MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
- REGISTERS_LOAD(AX, 0) // Restore most registers.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ PUSHQ PTRACE_SS(AX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_FLAGS(AX)
+ PUSHQ PTRACE_CS(AX)
+ PUSHQ PTRACE_RIP(AX)
+ PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack.
+ PUSHQ CX // Save userCR3 on kernel stack.
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
IRET()
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
// See iret, above.
- MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
- MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
- MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
- MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
- REGISTERS_LOAD(GS, CPU_REGISTERS)
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ CPU_REGISTERS+PTRACE_SS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RSP(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_CS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RIP(AX)
+ REGISTERS_LOAD(AX, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX
IRET()
// See entry_amd64.go.
TEXT ·Start(SB),NOSPLIT,$0
- LOAD_KERNEL_STACK(AX) // Set the stack.
PUSHQ $0x0 // Previous frame pointer.
MOVQ SP, BP // Set frame pointer.
PUSHQ AX // First argument (CPU).
@@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0
// See entry_amd64.go.
TEXT ·sysenter(SB),NOSPLIT,$0
- // Interrupts are always disabled while we're executing in kernel mode
- // and always enabled while executing in user mode. Therefore, we can
- // reliably look at the flags in R11 to determine where this syscall
- // was from.
- TESTL $_RFLAGS_IF, R11
+ // _RFLAGS_IOPL0 is always set in user mode and never set in kernel
+ // mode. See the comment on UserFlagsSet for more details.
+ TESTL $_RFLAGS_IOPL0, R11
JZ kernel
-
user:
SWAP_GS()
- XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
- MOVQ BX, PTRACE_RAX(AX) // Save everything else.
- MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ CX, PTRACE_RIP(AX)
MOVQ R11, PTRACE_FLAGS(AX)
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ SP, PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value.
+ MOVQ CX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ CX, PTRACE_ORIGRAX(AX)
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks.
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
// Return to the kernel, where the frame is:
//
- // vector (sp+24)
+ // vector (sp+32)
+ // userCR3 (sp+24)
// regs (sp+16)
// cpu (sp+8)
// vcpu.Switch (sp+0)
//
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ $Syscall, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ $Syscall, 32(SP) // Output vector.
RET
kernel:
// We can't restore the original stack, but we can access the registers
// in the CPU state directly. No need for temporary juggling.
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ AX, ENTRY_SCRATCH0(GS)
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ AX // First argument (vCPU).
CALL ·kernelSyscall(SB) // Call the trampoline.
POPQ AX // Pop vCPU.
@@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0
// ERROR_CODE (sp+8)
// VECTOR (sp+0)
//
- TESTL $_RFLAGS_IF, 32(SP)
+ TESTL $_RFLAGS_IOPL0, 32(SP)
JZ kernel
user:
SWAP_GS()
ADDQ $-8, SP // Adjust for flags.
MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ PUSHQ AX // Save user AX on stack.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ POPQ BX // Restore original AX.
MOVQ BX, PTRACE_RAX(AX) // Save it.
MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
@@ -249,34 +299,36 @@ user:
MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
// Copy out and return.
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
MOVQ 0(SP), BX // Load vector.
MOVQ 8(SP), CX // Load error code.
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
- MOVQ BX, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(AX) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
+ MOVQ BX, 32(SP) // Output vector.
RET
kernel:
// As per above, we can save directly.
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+ PUSHQ AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ POPQ BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX)
// Set the error code and adjust the stack.
- MOVQ 8(SP), AX // Load the error code.
- MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 8(SP), BX // Load the error code.
+ MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
MOVQ 0(SP), BX // BX contains the vector.
- ADDQ $48, SP // Drop the exception frame.
// Call the exception trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ BX // Second argument (vector).
PUSHQ AX // First argument (vCPU).
CALL ·kernelException(SB) // Call the trampoline.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 5f63cbd45..274576e2d 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -47,8 +47,9 @@
#define SCTLR_C 1 << 2
#define SCTLR_I 1 << 12
#define SCTLR_UCT 1 << 15
+#define SCTLR_UCI 1 << 26
-#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT | SCTLR_UCI)
// cntkctl_el1: counter-timer kernel control register el1.
#define CNTKCTL_EL0PCTEN 1 << 0
@@ -296,9 +297,7 @@
LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
MOVD RSV_REG, RSP; \
- WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
- ISB $15; \
- DSB $15;
+ WORD $0xd538d092; //MRS TPIDR_EL1, R18
// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
#define SWITCH_TO_APP_PAGETABLE(from) \
@@ -342,6 +341,8 @@
ADD $16, RSP, RSP; \
MOVD RSV_REG, PTRACE_R18(R20); \
MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MRS TPIDR_EL0, R3; \
+ MOVD R3, PTRACE_TLS(R20); \
WORD $0xd5384003; \ // MRS SPSR_EL1, R3
MOVD R3, PTRACE_PSTATE(R20); \
MRS ELR_EL1, R3; \
@@ -354,6 +355,8 @@
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ MRS TPIDR_EL0, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_TLS(RSV_REG); \
WORD $0xd5384004; \ // MRS SPSR_EL1, R4
MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
MRS ELR_EL1, R4; \
@@ -377,8 +380,6 @@ TEXT ·Halt(SB),NOSPLIT,$0
BNE mmio_exit
MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG)
- // Flush dcache.
- WORD $0xd5087e52 // DC CISW
mmio_exit:
// Disable fpsimd.
WORD $0xd5381041 // MRS CPACR_EL1, R1
@@ -396,9 +397,6 @@ mmio_exit:
MRS VBAR_EL1, R9
MOVD R0, 0x0(R9)
- // Flush dcahce.
- WORD $0xd5087e52 // DC CISW
-
RET
// HaltAndResume halts execution and point the pointer to the resume function.
@@ -435,6 +433,8 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MRS TPIDR_EL1, RSV_REG
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+ MRS TPIDR_EL0, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_TLS(RSV_REG)
WORD $0xd5384003 // MRS SPSR_EL1, R3
MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
@@ -461,8 +461,18 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MOVD PTRACE_PSTATE(RSV_REG_APP), R1
WORD $0xd5184001 //MSR R1, SPSR_EL1
+ // We need to use a kernel-space address to execute the code below,
+ // since after SWITCH_TO_APP_PAGETABLE the ASID has been changed to
+ // the application's ASID.
+ WORD $0x10000061 // ADR R1, do_exit_to_el0
+ ORR $0xffff000000000000, R1, R1
+ JMP (R1)
+
+do_exit_to_el0:
// RSV_REG & RSV_REG_APP will be loaded at the end.
REGISTERS_LOAD(RSV_REG_APP, 0)
+ MOVD PTRACE_TLS(RSV_REG_APP), RSV_REG
+ MSR RSV_REG, TPIDR_EL0
// switch to user pagetable.
MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
@@ -505,8 +515,6 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
- // Flush dcache.
- WORD $0xd5087e52 // DC CISW
// Init.
MOVD $SCTLR_EL1_DEFAULT, R1
MSR R1, SCTLR_EL1
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 549f3d228..9742308d8 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,7 +24,10 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
- visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
+ visibility = [
+ "//pkg/sentry/platform/kvm:__pkg__",
+ "//pkg/sentry/platform/ring0:__pkg__",
+ ],
deps = [
"//pkg/cpuid",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 021693791..264be23d3 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -19,8 +19,8 @@ package ring0
// N.B. that constraints on KernelOpts must be satisfied.
//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts) {
- k.init(opts)
+func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
+ k.init(opts, maxCPUs)
}
// Halt halts execution.
@@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) {
// kernelSyscall is a trampoline.
//
+// On amd64, it is called with %rip in the upper half, so it must NOT
+// access any global data that is not mapped in the upper half, and it
+// must call through function pointers or interfaces to switch to the
+// lower half so that the callee can access global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) {
// kernelException is a trampoline.
//
+// On amd64, it is called with %rip in the upper half, so it must NOT
+// access any global data that is not mapped in the upper half, and it
+// must call through function pointers or interfaces to switch to the
+// lower half so that the callee can access global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) {
// Init initializes a new CPU.
//
// Init allows embedding in other objects.
-func (c *CPU) Init(k *Kernel, hooks Hooks) {
- c.self = c // Set self reference.
- c.kernel = k // Set kernel reference.
- c.init() // Perform architectural init.
+func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init(cpuID) // Perform architectural init.
// Require hooks.
if hooks != nil {
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index d37981dbf..3a9dff4cc 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -18,13 +18,42 @@ package ring0
import (
"encoding/binary"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/usermem"
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
+ entrySize := reflect.TypeOf(kernelEntry{}).Size()
+ var (
+ entries []kernelEntry
+ padding = 1
+ )
+ for {
+ entries = make([]kernelEntry, maxCPUs+padding-1)
+ totalSize := entrySize * uintptr(maxCPUs+padding-1)
+ addr := reflect.ValueOf(&entries[0]).Pointer()
+ if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize {
+ // The runtime forces power-of-2 alignment for allocations, and we are therefore
+ // safe once the first address is aligned and the chunk is at least a full page.
+ break
+ }
+ padding = padding << 1
+ }
+ k.cpuEntries = entries
+
+ k.globalIDT = &idt64{}
+ if reflect.TypeOf(idt64{}).Size() != usermem.PageSize {
+ panic("Size of globalIDT should be PageSize")
+ }
+ if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 {
+ panic("Allocated globalIDT should be page aligned")
+ }
+
// Setup the IDT, which is uniform.
for v, handler := range handlers {
// Allow Breakpoint and Overflow to be called from all
@@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) {
}
}
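+// EntryRegions returns the set of kernel entry regions (must be mapped).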
+func (k *Kernel) EntryRegions() map[uintptr]uintptr {
+ regions := make(map[uintptr]uintptr)
+
+ addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer()
+ size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries))
+ end, _ := usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ addr = reflect.ValueOf(k.globalIDT).Pointer()
+ size = reflect.TypeOf(idt64{}).Size()
+ end, _ = usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ return regions
+}
+
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
+ c.kernelEntry = &c.kernel.cpuEntries[cpuID]
+ c.cpuSelf = c
// Null segment.
c.gdt[0].setNull()
@@ -65,6 +112,7 @@ func (c *CPU) init() {
// Set the kernel stack pointer in the TSS (virtual address).
stackAddr := c.StackTop()
+ c.stackTop = stackAddr
c.tss.rsp0Lo = uint32(stackAddr)
c.tss.rsp0Hi = uint32(stackAddr >> 32)
c.tss.ist1Lo = uint32(stackAddr)
@@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
- kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+ c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID))
// Sanitize registers.
regs := switchOpts.Registers
@@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
- jumpToKernel() // Switch to upper half.
- writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
- vector = iret(c, regs)
+ vector = iret(c, regs, uintptr(userCR3))
} else {
- vector = sysret(c, regs)
+ vector = sysret(c, regs, uintptr(userCR3))
}
- writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
- jumpToUser() // Return to lower half.
SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
@@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
//go:nosplit
func start(c *CPU) {
// Save per-cpu & FS segment.
- WriteGS(kernelAddr(c))
+ WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
// Initialize floating point.
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index 14774c5db..b294ccc7c 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,13 +25,13 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
}
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
// Set the kernel stack pointer(virtual address).
c.registers.Sp = uint64(c.StackTop())
@@ -64,11 +64,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate |= UserFlagsSet
LoadFloatingPoint(switchOpts.FloatingPointState)
- SetTLS(regs.TPIDR_EL0)
kernelExitToEl0()
- regs.TPIDR_EL0 = GetTLS()
SaveFloatingPoint(switchOpts.FloatingPointState)
vector = c.vecCode
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
index ca968a036..0ec5c3bc5 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.go
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -61,21 +61,9 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
-// writeCR3 writes the CR3 value.
-func writeCR3(phys uintptr)
-
-// readCR3 reads the current CR3 value.
-func readCR3() uintptr
-
// readCR2 reads the current CR2 value.
func readCR2() uintptr
-// jumpToKernel jumps to the kernel version of the current RIP.
-func jumpToKernel()
-
-// jumpToUser jumps to the user version of the current RIP.
-func jumpToUser()
-
// fninit initializes the floating point unit.
func fninit()
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
index 75d742750..2fe83568a 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.s
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30; // WRMSR
RET
-// jumpToUser changes execution to the user address.
-//
-// This works by changing the return value to the user version.
-TEXT ·jumpToUser(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- NOTQ BX
- ANDQ BX, SP // Switch the stack.
- ANDQ BX, BP // Switch the frame pointer.
- ANDQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// jumpToKernel changes execution to the kernel address space.
-//
-// This works by changing the return value to the kernel version.
-TEXT ·jumpToKernel(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- ORQ BX, SP // Switch the stack.
- ORQ BX, BP // Switch the frame pointer.
- ORQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// writeCR3 writes the given CR3 value.
-//
-// The code corresponds to:
-//
-// mov %rax, %cr3
-//
-TEXT ·writeCR3(SB),NOSPLIT,$0-8
- MOVQ cr3+0(FP), AX
- BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
- RET
-
-// readCR3 reads the current CR3 value.
-//
-// The code corresponds to:
-//
-// mov %cr3, %rax
-//
-TEXT ·readCR3(SB),NOSPLIT,$0-8
- BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
- MOVQ AX, ret+0(FP)
- RET
-
// readCR2 reads the current CR2 value.
//
// The code corresponds to:
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index b8ab120a0..ca4075b09 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -30,14 +30,21 @@ func Emit(w io.Writer) {
c := &CPU{}
fmt.Fprintf(w, "\n// CPU offsets.\n")
- fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
- fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+
+ e := &kernelEntry{}
+ fmt.Fprintf(w, "\n// CPU entry offsets.\n")
+ fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
+ fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0)
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 1d86b4bcf..45eba960d 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -125,4 +125,5 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_TLS 0x%02x\n", reflect.ValueOf(&p.TPIDR_EL0).Pointer()-reflect.ValueOf(p).Pointer())
}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 6409d1d91..520161755 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -78,7 +78,7 @@ const (
const (
executeDisable = xn
- optionMask = 0xfff | 0xfff<<48
+ optionMask = 0xfff | 0xffff<<48
protDefault = accessed | shared
)
@@ -188,7 +188,7 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) {
v |= mtNormal
} else {
v = v &^ user
- v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress.
+ v |= mtNormal
}
atomic.StoreUintptr((*uintptr)(p), v)
}
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 9da0ea685..34fbc1c35 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -39,7 +39,9 @@ const (
_RFLAGS_AC = 1 << 18
_RFLAGS_NT = 1 << 14
- _RFLAGS_IOPL = 3 << 12
+ _RFLAGS_IOPL0 = 1 << 12
+ _RFLAGS_IOPL1 = 1 << 13
+ _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1
_RFLAGS_DF = 1 << 10
_RFLAGS_IF = 1 << 9
_RFLAGS_STEP = 1 << 8
@@ -67,15 +69,45 @@ const (
KernelFlagsSet = _RFLAGS_RESERVED
// UserFlagsSet are always set in userspace.
- UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF
+ //
+ // _RFLAGS_IOPL is a set of two bits that encodes the I/O privilege
+ // level. The Current Privilege Level (CPL) of the task must be less
+ // than or equal to the IOPL in order for the task or program to access
+ // I/O ports.
+ //
+ // Here, _RFLAGS_IOPL0 is used only to determine whether the task is
+ // running in kernel or user mode. In user mode, the CPL is always 3,
+ // and it doesn't matter what the IOPL is set to if it is below the CPL.
+ //
+ // We need one bit that is always different in user and kernel modes.
+ // And we have to remember that even though we have KernelFlagsClear,
+ // we can still see some of these flags in kernel mode. This can happen
+ // when the Go runtime switches to a goroutine that was saved in host
+ // mode. On restore, the popf instruction is used to restore flags,
+ // which means that all the flags the goroutine had in host mode will
+ // be restored in kernel mode.
+ //
+ // _RFLAGS_IOPL0 is never set in host and kernel modes, and we always
+ // set it in user mode. So if this flag is set, the task is running in
+ // user mode; if it isn't set, the task is running in kernel mode.
+ UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0
// KernelFlagsClear should always be clear in the kernel.
KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT
// UserFlagsClear are always cleared in userspace.
- UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL
+ UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1
)
+// IsKernelFlags returns true if rflags corresponds to kernel mode.
+//
+//go:nosplit
+func IsKernelFlags(rflags uint64) bool {
+ return rflags&_RFLAGS_IOPL0 == 0
+}
+
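A small self-contained sketch of the mode test, with the relevant flag values inlined (not gVisor code):

package main

import "fmt"

func main() {
	const (
		rflagsReserved = uint64(1 << 1)  // _RFLAGS_RESERVED: always set.
		rflagsIF       = uint64(1 << 9)  // _RFLAGS_IF.
		rflagsIOPL0    = uint64(1 << 12) // _RFLAGS_IOPL0.
	)
	userFlags := rflagsReserved | rflagsIF | rflagsIOPL0 // As in UserFlagsSet.
	kernFlags := rflagsReserved                          // As in KernelFlagsSet.
	for _, f := range []uint64{userFlags, kernFlags} {
		fmt.Printf("flags %#x: kernel=%v\n", f, f&rflagsIOPL0 == 0)
	}
}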
// Vector is an exception vector.
type Vector uintptr
@@ -104,7 +136,7 @@ const (
VirtualizationException
SecurityException = 0x1e
SyscallInt80 = 0x80
- _NR_INTERRUPTS = SyscallInt80 + 1
+ _NR_INTERRUPTS = 0x100
)
// System call vectors.
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 94cb80437..904a12e38 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -147,10 +147,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
case stack.FilterTable:
table = stack.EmptyFilterTable()
case stack.NATTable:
- if ipv6 {
- nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)")
- return syserr.ErrInvalidArgument
- }
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index e3b108e93..0e14447fe 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -47,6 +47,9 @@ func init() {
registerTargetMaker(&redirectTargetMaker{
NetworkProtocol: header.IPv4ProtocolNumber,
})
+ registerTargetMaker(&nfNATTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
}
type standardTargetMaker struct {
@@ -194,11 +197,9 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte {
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
- if rt.RangeProtoSpecified {
- xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
- }
- xt.NfRange.RangeIPV4.MinPort = htons(rt.MinPort)
- xt.NfRange.RangeIPV4.MaxPort = htons(rt.MaxPort)
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
@@ -231,23 +232,103 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
// Also check if we need to map ports or IP.
// For now, redirect target only supports destination port change.
// Port range and IP range are not supported yet.
- if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags)
return nil, syserr.ErrInvalidArgument
}
- target.RangeProtoSpecified = true
-
- target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
- target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
// TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
return nil, syserr.ErrInvalidArgument
}
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
+}
+
+type nfNATTarget struct {
+ Target linux.XTEntryTarget
+ Range linux.NFNATRange
+}
+
+const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
- target.MinPort = ntohs(nfRange.RangeIPV4.MinPort)
- target.MaxPort = ntohs(nfRange.RangeIPV4.MaxPort)
+type nfNATTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *nfNATTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
+ nt := nfNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: nfNATMarshalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+ },
+ }
+ copy(nt.Target.Name[:], stack.RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.Addr)
+ copy(nt.Range.MaxAddr[:], rt.Addr)
+
+ nt.Range.MinProto = htons(rt.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarshalledSize)
+ return binary.Marshal(ret, usermem.ByteOrder, nt)
+}
+
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if size := nfNATMarshalledSize; len(buf) < size {
+ nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("nfNATTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ binary.Unmarshal(buf, usermem.ByteOrder, &natRange)
+
+ // We don't support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("nfNATTargetMaker: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("nfNATTargetMaker: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/3549): Check for other flags.
+ // For now, redirect target only supports destination change.
+ if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ }
return &target, nil
}
@@ -308,7 +389,7 @@ func (jt *JumpTarget) ID() stack.TargetID {
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
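Both target makers keep ports in network byte order inside the marshalled structs via htons/ntohs. A sketch of those helpers as a plain byte swap, assuming a little-endian host (the conversion is then its own inverse):

package main

import "fmt"

func htons(v uint16) uint16 { return v<<8 | v>>8 } // host to network, little-endian host assumed
func ntohs(v uint16) uint16 { return v<<8 | v>>8 } // network to host: the same swap

func main() {
	const port = 8080 // 0x1f90
	fmt.Printf("host %#04x -> network %#04x -> host %#04x\n", port, htons(port), ntohs(htons(port)))
}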
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 335822c0e..87e30d742 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1512,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &vP, nil
case linux.IP6T_ORIGINAL_DST:
- // TODO(gvisor.dev/issue/170): ip6tables.
- return nil, syserr.ErrInvalidArgument
+ if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
if outLen < linux.SizeOfIPTGetinfo {
@@ -1555,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
return &entries, nil
+ case linux.IP6T_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 75752b2e6..a2e441448 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -21,6 +21,7 @@ go_library(
"sys_identity.go",
"sys_inotify.go",
"sys_lseek.go",
+ "sys_membarrier.go",
"sys_mempolicy.go",
"sys_mmap.go",
"sys_mount.go",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 5f26697d2..9c9def7cd 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -376,7 +376,7 @@ var AMD64 = &kernel.SyscallTable{
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
322: syscalls.Supported("execveat", Execveat),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls implemented after 325 are "backports" from versions
@@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{
96: syscalls.Supported("set_tid_address", SetTidAddress),
97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 99: syscalls.Supported("set_robust_list", SetRobustList),
+ 100: syscalls.Supported("get_robust_list", GetRobustList),
101: syscalls.Supported("nanosleep", Nanosleep),
102: syscalls.Supported("getitimer", Getitimer),
103: syscalls.Supported("setitimer", Setitimer),
@@ -695,7 +695,7 @@ var ARM64 = &kernel.SyscallTable{
280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
281: syscalls.Supported("execveat", Execveat),
282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls after 284 are "backports" from versions of Linux after 4.4.
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 98331eb3c..519066a47 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -84,6 +84,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
}
rel = f.Dirent
if !fs.IsDir(rel.Inode.StableAttr) {
+ f.DecRef(t)
return syserror.ENOTDIR
}
}
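The fileOpOn change above plugs a reference leak on the ENOTDIR early return. A generic sketch of the bug class, using a deferred release so every path, including error returns, drops the reference (the ref type here is hypothetical, not the sentry's):

package main

import "fmt"

type ref struct{ n int }

func (r *ref) IncRef() { r.n++ }
func (r *ref) DecRef() { r.n-- }

func use(r *ref, isDir bool) error {
	r.IncRef()
	defer r.DecRef() // runs on every path, including the error below
	if !isDir {
		return fmt.Errorf("not a directory")
	}
	return nil
}

func main() {
	r := &ref{}
	_ = use(r, false)
	fmt.Println("refcount after error path:", r.n) // 0: no leak
}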
diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go
new file mode 100644
index 000000000..63ee5d435
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_membarrier.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Membarrier implements syscall membarrier(2).
+func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ cmd := args[0].Int()
+ flags := args[1].Uint()
+
+ switch cmd {
+ case linux.MEMBARRIER_CMD_QUERY:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ var supportedCommands uintptr
+ if t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ supportedCommands |= linux.MEMBARRIER_CMD_GLOBAL |
+ linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED
+ }
+ if t.RSeqAvailable() {
+ supportedCommands |= linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ
+ }
+ return supportedCommands, nil, nil
+ case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier()
+ case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ // no-op
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierPrivate()
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ:
+ if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.MemoryManager().IsMembarrierRSeqEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ // MEMBARRIER_CMD_FLAG_CPU and cpu_id are ignored since we don't have
+ // the ability to preempt specific CPUs.
+ return 0, nil, t.Kernel().Platform.PreemptAllCPUs()
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierRSeq()
+ return 0, nil, nil
+ default:
+ // Probably a command we don't implement.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.EINVAL
+ }
+}
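For reference, the userspace flow the new handler supports is query, register, then barrier. A sketch using raw syscalls via golang.org/x/sys/unix; the command values are taken from linux/membarrier.h and the error handling is trimmed:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const (
	cmdQuery                    = 0
	cmdPrivateExpedited         = 1 << 3
	cmdRegisterPrivateExpedited = 1 << 4
)

func membarrier(cmd, flags uintptr) (uintptr, error) {
	r, _, errno := unix.Syscall(unix.SYS_MEMBARRIER, cmd, flags, 0)
	if errno != 0 {
		return 0, errno
	}
	return r, nil
}

func main() {
	supported, err := membarrier(cmdQuery, 0)
	if err != nil || supported&cmdPrivateExpedited == 0 {
		fmt.Println("private expedited membarrier unavailable")
		return
	}
	// Registering first matters: the handler above returns EPERM for an
	// unregistered MEMBARRIER_CMD_PRIVATE_EXPEDITED.
	if _, err := membarrier(cmdRegisterPrivateExpedited, 0); err != nil {
		fmt.Println("register failed:", err)
		return
	}
	if _, err := membarrier(cmdPrivateExpedited, 0); err != nil {
		fmt.Println("barrier failed:", err)
		return
	}
	fmt.Println("issued a private expedited membarrier")
}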
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 674d341b6..6320593f0 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -26,8 +26,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
addr := args[0].Pointer()
mf := t.Kernel().MemoryFile()
- mf.UpdateUsage()
- _, totalUsage := usage.MemoryAccounting.Copy()
+ mfUsage, err := mf.TotalUsage()
+ if err != nil {
+ return 0, nil, err
+ }
+ memStats, _ := usage.MemoryAccounting.Copy()
+ totalUsage := mfUsage + memStats.Mapped
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
memFree := totalSize - totalUsage
if memFree > totalSize {
@@ -37,12 +41,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Only a subset of the fields in sysinfo_t make sense to return.
si := linux.Sysinfo{
- Procs: uint16(len(t.PIDNamespace().Tasks())),
+ Procs: uint16(t.Kernel().TaskSet().Root.NumTasks()),
Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
TotalRAM: totalSize,
FreeRAM: memFree,
Unit: 1,
}
- _, err := si.CopyOut(t, addr)
+ _, err = si.CopyOut(t, addr)
return 0, nil, err
}
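The guest-visible effect of the Sysinfo change can be checked with the standard wrapper; since the sentry sets Unit to 1, the RAM figures below are plain bytes:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	var si unix.Sysinfo_t
	if err := unix.Sysinfo(&si); err != nil {
		panic(err)
	}
	unit := uint64(si.Unit)
	fmt.Printf("procs=%d totalram=%d freeram=%d uptime=%ds\n",
		si.Procs, si.Totalram*unit, si.Freeram*unit, si.Uptime)
}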
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index 066ee0863..c8ce2aabc 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -110,8 +110,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user
}
// Load the new TaskContext.
- mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
- defer mntns.DecRef(t)
+ mntns := t.MountNamespaceVFS2()
wd := t.FSContext().WorkingDirectoryVFS2()
defer wd.DecRef(t)
remainingTraversals := uint(linux.MaxSymlinkTraversals)
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 0df3bd449..c50fd97eb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -163,6 +163,7 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
s.Table[5] = syscalls.Supported("setxattr", SetXattr)
s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
@@ -200,6 +201,7 @@ func Override() {
s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[45] = syscalls.Supported("truncate", Truncate)
s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
s.Table[48] = syscalls.Supported("faccessat", Faccessat)
s.Table[49] = syscalls.Supported("chdir", Chdir)
s.Table[50] = syscalls.Supported("fchdir", Fchdir)
@@ -221,12 +223,14 @@ func Override() {
s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
s.Table[69] = syscalls.Supported("preadv", Preadv)
s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[71] = syscalls.Supported("sendfile", Sendfile)
s.Table[72] = syscalls.Supported("pselect", Pselect)
s.Table[73] = syscalls.Supported("ppoll", Ppoll)
s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
s.Table[76] = syscalls.Supported("splice", Splice)
s.Table[77] = syscalls.Supported("tee", Tee)
s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[79] = syscalls.Supported("newfstatat", Newfstatat)
s.Table[80] = syscalls.Supported("fstat", Fstat)
s.Table[81] = syscalls.Supported("sync", Sync)
s.Table[82] = syscalls.Supported("fsync", Fsync)
@@ -251,8 +255,10 @@ func Override() {
s.Table[210] = syscalls.Supported("shutdown", Shutdown)
s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[213] = syscalls.Supported("readahead", Readahead)
s.Table[221] = syscalls.Supported("execve", Execve)
s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil)
s.Table[242] = syscalls.Supported("accept4", Accept4)
s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
s.Table[267] = syscalls.Supported("syncfs", Syncfs)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 8093ca55c..c855608db 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -92,7 +92,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index bdfd3ca8f..7ad0eaf86 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -61,11 +61,14 @@ func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *au
panic("cannot instaniate an anon filesystem")
}
-// Name implemenents FilesystemType.Name.
+// Name implements FilesystemType.Name.
func (anonFilesystemType) Name() string {
return "none"
}
+// Release implements FilesystemType.Release.
+func (anonFilesystemType) Release(ctx context.Context) {}
+
// anonFilesystem is the implementation of FilesystemImpl that backs
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
//
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 1eba0270f..183957ad8 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -827,7 +827,7 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
}
// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/etc.
+// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
type FileReadWriteSeeker struct {
FD *FileDescription
Ctx context.Context
@@ -835,11 +835,18 @@ type FileReadWriteSeeker struct {
WOpts WriteOptions
}
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
+ return int(n), err
+}
+
// Read implements io.ReadWriteSeeker.Read.
func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
dst := usermem.BytesIOSequence(p)
- ret, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(ret), err
+ n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
+ return int(n), err
}
// Seek implements io.ReadWriteSeeker.Seek.
@@ -847,9 +854,16 @@ func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
return f.FD.Seek(f.Ctx, offset, int32(whence))
}
+// WriteAt implements io.WriterAt.WriteAt.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
+ return int(n), err
+}
+
// Write implements io.ReadWriteSeeker.Write.
func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
buf := usermem.BytesIOSequence(p)
- ret, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(ret), err
+ n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
+ return int(n), err
}
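Adding ReadAt/WriteAt means a FileDescription can now back any consumer of io.ReaderAt/io.WriterAt, such as io.NewSectionReader. The same adapter contract in miniature, with a strings.Reader standing in for the wrapped FD:

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	var ra io.ReaderAt = strings.NewReader("hello, world")
	// io.NewSectionReader needs exactly the ReaderAt contract added above.
	sec := io.NewSectionReader(ra, 7, 5)
	b, _ := io.ReadAll(sec)
	fmt.Printf("%s\n", b) // world
}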
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index bc19db1d5..9d54cc4ed 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -33,6 +33,9 @@ type FilesystemType interface {
// Name returns the name of this FilesystemType.
Name() string
+
+ // Release releases all resources held by this FilesystemType.
+ Release(ctx context.Context)
}
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index dfc3ae6c0..78f115bfa 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -46,8 +46,9 @@ import (
// +stateify savable
type Mount struct {
// vfs, fs, root are immutable. References are held on fs and root.
+ // Note that for a disconnected mount, root may be nil.
//
- // Invariant: root belongs to fs.
+ // Invariant: if not nil, root belongs to fs.
vfs *VirtualFilesystem
fs *Filesystem
root *Dentry
@@ -498,7 +499,9 @@ func (mnt *Mount) DecRef(ctx context.Context) {
mnt.vfs.mounts.seq.EndWrite()
mnt.vfs.mountMu.Unlock()
}
- mnt.root.DecRef(ctx)
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
+ }
mnt.fs.DecRef(ctx)
if vd.Ok() {
vd.DecRef(ctx)
@@ -724,14 +727,12 @@ func (mnt *Mount) Root() *Dentry {
return mnt.root
}
-// Root returns mntns' root. A reference is taken on the returned
-// VirtualDentry.
+// Root returns mntns' root. It does not take a reference on the returned VirtualDentry.
func (mntns *MountNamespace) Root() VirtualDentry {
vd := VirtualDentry{
mount: mntns.root,
dentry: mntns.root.root,
}
- vd.IncRef()
return vd
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 5bd756ea5..38d2701d2 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -158,6 +158,16 @@ func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
return nil
}
+// Release drops references on filesystem objects held by vfs.
+//
+// Precondition: This must be called after VFS.Init() has succeeded.
+func (vfs *VirtualFilesystem) Release(ctx context.Context) {
+ vfs.anonMount.DecRef(ctx)
+ for _, fst := range vfs.fsTypes {
+ fst.fsType.Release(ctx)
+ }
+}
+
// PathOperation specifies the path operated on by a VFS method.
//
// PathOperation is passed to VFS methods by pointer to reduce memory copying:
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index ea0c5413d..8db70a700 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -84,8 +84,8 @@ type VectorisedView struct {
size int
}
-// NewVectorisedView creates a new vectorised view from an already-allocated slice
-// of View and sets its size.
+// NewVectorisedView creates a new vectorised view from an already-allocated
+// slice of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
@@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) {
}
// Clone returns a clone of this VectorisedView.
-// If the buffer argument is large enough to contain all the Views of this VectorisedView,
-// the method will avoid allocations and use the buffer to store the Views of the clone.
+// If the buffer argument is large enough to contain all the Views of this
+// VectorisedView, the method will avoid allocations and use the buffer to
+// store the Views of the clone.
func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
@@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) {
return newFirst, true
}
-// Size returns the size in bytes of the entire content stored in the vectorised view.
+// Size returns the size in bytes of the entire content stored in the
+// vectorised view.
func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
+ return vv.ToOwnedView()
+}
+
+// ToOwnedView returns a single view containing the content of the vectorised
+// view. The returned view is freshly allocated, so vv does not own it.
+func (vv *VectorisedView) ToOwnedView() View {
u := make([]byte, 0, vv.size)
for _, v := range vv.views {
u = append(u, v...)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 19627fa9b..d4d785cca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
+ }
+ }
+}
+
+// IPv4Options returns a checker that checks the options in an IPv4 packet.
+func IPv4Options(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ options := ip.Options()
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(options) == 0 {
+ return
+ }
+ if diff := cmp.Diff(want, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
}
}
}
@@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker {
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
@@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker {
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
@@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
+// not contain any SACK blocks in the TCP options.
func TCPNoSACKBlockChecker() TransportChecker {
return TCPSACKBlockChecker(nil)
}
@@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker {
}
}
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
+// and potentially additional ICMPv4 header fields.
func ICMPv4(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
@@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}
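All of the checkers above follow one combinator pattern: a constructor closes over the expected value and returns a function that a driver applies to the parsed packet. A stripped-down sketch of the pattern with stand-in types (not the netstack ones):

package checkerdemo

import "testing"

type packet struct{ ttl uint8 }

// checker validates one property of a parsed packet.
type checker func(t *testing.T, p packet)

// ttl captures the expected value and returns the validating closure.
func ttl(want uint8) checker {
	return func(t *testing.T, p packet) {
		t.Helper()
		if p.ttl != want {
			t.Errorf("Bad TTL, got = %d, want = %d", p.ttl, want)
		}
	}
}

// check applies a sequence of checkers, mirroring checker.IPv4/TCP/UDP drivers.
func check(t *testing.T, p packet, checkers ...checker) {
	t.Helper()
	for _, c := range checkers {
		c(t, p)
	}
}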
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index eaface8cb..95ade0e5c 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
-// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// IsMulticastEthernetAddress returns true if the address is a multicast
+// ethernet address.
+func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
+}
+
+// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
- // Must be of the right length.
if len(addr) != EthernetAddressSize {
return false
}
- // Must not be unspecified.
if addr == unspecifiedEthernetAddress {
return false
}
- // Must not be a multicast.
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
return false
}
- // addr is a valid unicast ethernet address.
return true
}
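IsMulticastEthernetAddress keys off the I/G bit, the least-significant bit of the first octet of a MAC address. A standalone illustration:

package main

import "fmt"

func isMulticastMAC(addr []byte) bool {
	// I/G bit set means a group (multicast/broadcast) address.
	return len(addr) == 6 && addr[0]&0x01 != 0
}

func main() {
	fmt.Println(isMulticastMAC([]byte{0x01, 0x00, 0x5e, 0x00, 0x00, 0x01})) // true: IPv4 multicast range
	fmt.Println(isMulticastMAC([]byte{0x02, 0x42, 0xac, 0x11, 0x00, 0x02})) // false: unicast
}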
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 14413f2ce..3bc8b2b21 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
}
}
+func TestIsMulticastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ true,
+ },
+ {
+ "Unicast",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsMulticastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
+
func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c00bcadfb..504408878 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
-// SetPointer sets the pointer field in a Parameter error packet.
-// This is the first byte of the type specific data field.
-func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
-
-// SetTypeSpecific sets the full 32 bit type specific data field.
-func (b ICMPv4) SetTypeSpecific(val uint32) {
- binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val)
-}
-
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4eb5abd79..4303fc5d5 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -49,11 +49,6 @@ const (
// neighbor advertisement packet.
ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
- // ICMPv6NeighborAdvertSize is size of a neighbor advertisement
- // including the NDP Target Link Layer option for an Ethernet
- // address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
-
// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
@@ -156,9 +151,14 @@ const (
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
+ // ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
ICMPv6ErroneousHeader ICMPv6Code = 0
- ICMPv6UnknownHeader ICMPv6Code = 1
- ICMPv6UnknownOption ICMPv6Code = 2
+
+ // ICMPv6UnknownHeader indicates an unrecognized Next Header type was encountered.
+ ICMPv6UnknownHeader ICMPv6Code = 1
+
+ // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
+ ICMPv6UnknownOption ICMPv6Code = 2
)
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
@@ -177,7 +177,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
-// SetTypeSpecific sets the full 32 bit type specific data field.
+// TypeSpecific returns the type specific data field.
+func (b ICMPv6) TypeSpecific() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
+}
+
+// SetTypeSpecific sets the type specific data field.
func (b ICMPv6) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index b07d9991d..4c6e4be64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,10 +16,29 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// RFC 791 defines the fields of the IPv4 header on page 11 using the following
+// diagram: ("Figure 4")
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Identification |Flags| Fragment Offset |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Time to Live | Protocol | Header Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Source Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Destination Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options | Padding |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
const (
versIHL = 0
tos = 1
@@ -33,6 +52,7 @@ const (
checksum = 10
srcAddr = 12
dstAddr = 16
+ options = 20
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -76,7 +96,8 @@ type IPv4Fields struct {
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
-// Always call IsValid() to validate an instance of IPv4 before using other methods.
+// Always call IsValid() to validate an instance of IPv4 before using other
+// methods.
type IPv4 []byte
const (
@@ -151,13 +172,44 @@ func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
- return int(b[versIHL] >> 4)
+ return int(b[versIHL] >> ipVersionShift)
}
+// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits
+// of the first byte, and is counted in multiples of 4 bytes.
+//
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// (...)
+// Version: 4 bits
+// The Version field indicates the format of the internet header. This
+// document describes version 4.
+//
+// IHL: 4 bits
+// Internet Header Length is the length of the internet header in 32
+// bit words, and thus points to the beginning of the data. Note that
+// the minimum value for a correct header is 5.
+//
+const (
+ ipVersionShift = 4
+ ipIHLMask = 0x0f
+ IPv4IHLStride = 4
+)
+
// HeaderLength returns the value of the "header length" field of the ipv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
- return (b[versIHL] & 0xf) * 4
+ return (b[versIHL] & ipIHLMask) * IPv4IHLStride
+}
+
+// SetHeaderLength sets the value of the "Internet Header Length" field.
+func (b IPv4) SetHeaderLength(hdrLen uint8) {
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
// ID returns the value of the identifier field of the ipv4 header.
@@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// Options returns a buffer holding the options.
+func (b IPv4) Options() []byte {
+ hdrLen := b.HeaderLength()
+ return b[options:hdrLen:hdrLen]
+}
+
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.Protocol())
@@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
+// SetTTL sets the "Time to Live" field of the IPv4 header.
+func (b IPv4) SetTTL(v byte) {
+ b[ttl] = v
+}
+
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
@@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 {
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
- b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
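SetHeaderLength above packs two fields into the first header byte: the version in the high nibble and the IHL, counted in 4-byte units, in the low nibble. A worked example for the common option-free 20-byte header:

package main

import "fmt"

func main() {
	const (
		version = 4
		hdrLen  = 20 // bytes, the minimum valid IPv4 header
	)
	versIHL := byte(version<<4 | (hdrLen/4)&0x0f)
	fmt.Printf("versIHL = %#02x\n", versIHL)             // 0x45, the classic first byte of an IPv4 packet
	fmt.Printf("decoded = %d bytes\n", (versIHL&0x0f)*4) // 20
}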
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 0761a1807..c5d8a3456 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -34,6 +34,9 @@ const (
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
+
+ // IPv6FixedHeaderSize is the size of the fixed header.
+ IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -69,7 +72,7 @@ type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
- IPv6MinimumSize = 40
+ IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
@@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool {
return addr[0] != 0xff
}
+const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
- const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// IsSolicitedNodeAddr determines whether the address is a solicited-node
+// multicast address.
+func IsSolicitedNodeAddr(addr tcpip.Address) bool {
+ return solicitedNodeMulticastPrefix == addr[:len(addr)-3]
+}
+
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
//
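SolicitedNodeAddr and its new inverse check both rely on the RFC 4291 mapping: take the low 24 bits of the unicast address and append them to ff02::1:ff00:0/104. A worked example using standard-library net types as stand-ins for tcpip.Address:

package main

import (
	"fmt"
	"net"
)

func solicitedNodeAddr(ip net.IP) net.IP {
	ip = ip.To16()
	out := make(net.IP, net.IPv6len)
	copy(out, net.ParseIP("ff02::1:ff00:0"))
	copy(out[13:], ip[13:]) // graft on the low 24 bits of the target
	return out
}

func main() {
	fmt.Println(solicitedNodeAddr(net.ParseIP("fe80::1234:5678:9abc"))) // ff02::1:ff78:9abc
}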
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 3499d8399..583c2c5d3 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader bytes.Reader
+
+ // optionOffset is the number of bytes from the first byte of the
+ // options field to the beginning of the current option.
+ optionOffset uint32
+
+ // nextOptionOffset is the offset of the next option.
+ nextOptionOffset uint32
+}
+
+// OptionOffset returns the number of bytes parsed while processing the
+// option field of the current Extension Header.
+func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
+ return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
@@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// the options data, or an error occurred.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
+ i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
@@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
+ i.nextOptionOffset = i.optionOffset + 1
continue
}
@@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
- // Special-case the variable length padding option to avoid a copy.
- if id == ipv6PadNExtHdrOptionIdentifier {
- // Do we have enough bytes in the reader for the PadN option?
- if n := i.reader.Len(); n < int(length) {
- // Reset the reader to effectively consume the remaining buffer.
- i.reader.Reset(nil)
-
- // We return the same error as if we failed to read a non-padding option
- // so consumers of this iterator don't need to differentiate between
- // padding and non-padding options.
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
- }
+ // Do we have enough bytes in the reader for the next option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
+ switch id {
+ case ipv6PadNExtHdrOptionIdentifier:
+ // Special-case the variable length padding option to avoid a copy.
if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
-
- // End processing of the PadN option and continue processing the buffer as
- // a new option.
continue
- }
-
- bytes := make([]byte, length)
- if n, err := io.ReadFull(&i.reader, bytes); err != nil {
- // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
- // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
- // Length field found in the option.
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ default:
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
-
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
-
- return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
@@ -382,6 +396,29 @@ type IPv6PayloadIterator struct {
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
+
+ // headerOffset is the offset of the beginning of the current extension
+ // header starting from the beginning of the fixed header.
+ headerOffset uint32
+
+ // parseOffset is the byte offset into the current extension header of the
+ // field we are currently examining. It can be added to the header offset
+ // if the absolute offset within the packet is required.
+ parseOffset uint32
+
+ // nextOffset is the offset of the next header.
+ nextOffset uint32
+}
+
+// HeaderOffset returns the offset to the start of the extension
+// header most recently processed.
+func (i IPv6PayloadIterator) HeaderOffset() uint32 {
+ return i.headerOffset
+}
+
+// ParseOffset returns the number of bytes successfully parsed.
+func (i IPv6PayloadIterator) ParseOffset() uint32 {
+ return i.headerOffset + i.parseOffset
}
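The nextOffset bookkeeping added below depends on the RFC 8200 rule that Hdr Ext Len counts 8-octet units excluding the first 8 octets. The arithmetic in one line:

package main

import "fmt"

func main() {
	const hdrExtLen = 2 // value read from the Hdr Ext Len byte
	fmt.Println((hdrExtLen + 1) * 8) // 24: total bytes this extension header occupies
}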
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
@@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
nextHdrIdentifier: nextHdrIdentifier,
payload: payload.Clone(nil),
// We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
- reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ nextOffset: IPv6FixedHeaderSize,
}
}
@@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occurred.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ i.headerOffset = i.nextOffset
+ i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
@@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
return IPv6RoutingExtHdr(bytes), false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
- // We ignore the returned bytes becauase we know the fragment extension
+ // We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
@@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
+ i.parseOffset++
var length uint8
length, err = i.reader.ReadByte()
i.payload.TrimFront(1)
+
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
@@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length = 0
}
+ // Make parseOffset point to the first byte of the Extension Header
+ // specific data.
+ i.parseOffset++
+
+ // length is in 8 byte chunks but doesn't include the first one.
+ // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
+ // in section 4.8 for new extension headers at the top of page 24.
+ // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
+ // units, not including the first 8 octets.
+ i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
+
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if bytes == nil {
bytes = make([]byte, bytesLen)
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
index b5540bf66..17a49d4fa 100644
--- a/pkg/tcpip/header/ipversion_test.go
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -22,7 +22,7 @@ import (
func TestIPv4(t *testing.T) {
b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
+ b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize})
const want = header.IPv4Version
if v := header.IPVersion(b); v != want {
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 522135557..5ca75c834 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -139,6 +139,7 @@ traverseExtensions:
// Returns true if the header was successfully parsed.
func UDP(pkt *stack.PacketBuffer) bool {
_, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
return ok
}
@@ -162,5 +163,6 @@ func TCP(pkt *stack.PacketBuffer) bool {
}
_, ok = pkt.TransportHeader().Consume(hdrLen)
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
return ok
}
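With these changes the parse helpers stamp the transport protocol number onto the packet buffer itself. A minimal sketch of a hypothetical caller relying on that (assumes pkt's network header was already consumed; deliverUDP is an illustrative name, not part of the change):

    // parse.UDP both consumes the UDP header and records the protocol
    // number on the buffer, so later layers need not re-derive it.
    func deliverUDP(pkt *stack.PacketBuffer) bool {
        if !parse.UDP(pkt) {
            return false // transport header too short
        }
        return pkt.TransportProtocolNumber == header.UDPProtocolNumber
    }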
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
new file mode 100644
index 000000000..9f31c1ffc
--- /dev/null
+++ b/pkg/tcpip/link/pipe/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = ["pipe.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..76f563811
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ capabilities: capabilities,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ linked: ep1,
+ capabilities: capabilities,
+ }
+ ep1.linked = ep2
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ capabilities stack.LinkEndpointCapabilities
+ linkAddr tcpip.LinkAddress
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ onWritePacket func(*stack.PacketBuffer)
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // The pipe endpoint will accept all multicast/broadcast link traffic and only
+ // unicast traffic destined to itself.
+ if len(e.linked.linkAddr) != 0 &&
+ r.RemoteLinkAddress != e.linked.linkAddr &&
+ r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
+ !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
+ return nil
+ }
+
+ e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
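A minimal sketch of wiring the two ends of a pipe into two stacks (hypothetical NIC IDs, link addresses, and stack options; CreateNIC is the existing stack API):

    ep1, ep2 := pipe.New("\x02\x03\x04\x05\x06\x07", "\x02\x03\x04\x05\x06\x08", stack.CapabilityResolutionRequired)
    s1 := stack.New(stack.Options{ /* protocols elided */ })
    s2 := stack.New(stack.Options{ /* protocols elided */ })
    if err := s1.CreateNIC(1, ep1); err != nil {
        // handle error
    }
    if err := s2.CreateNIC(1, ep2); err != nil {
        // handle error
    }
    // Packets written out of s1's NIC are delivered to s2's dispatcher,
    // subject to the unicast/multicast filter in WritePacket above.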
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index b6ddbe81e..f94491026 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) {
}
}
+// NICID returns the NIC ID of the device.
+//
+// Must only be called after the device has been attached to an endpoint.
+func (d *Device) NICID() tcpip.NICID {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+
+ if d.endpoint == nil {
+ panic("called NICID on a device that has not been attached")
+ }
+
+ return d.endpoint.nicID
+}
+
// SetIff services TUNSETIFF ioctl(2) request.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+//
+// Returns true if a new NIC was created; false if an existing one was attached.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
// Input validations.
@@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
prefix := "tun"
@@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return nil
+ return created, nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
- if nic, found := s.GetNICByName(name); found {
- endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
+ endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, false, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, nil
+ return endpoint, false, nil
}
}
@@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, nil
+ return endpoint, true, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, false, syserror.EINVAL
}
}
}
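A sketch of how a caller might consume the new (bool, error) return from SetIff (hypothetical flags; the real caller is the TUNSETIFF ioctl path in the sentry):

    created, err := d.SetIff(s, "tap0", linux.IFF_TAP|linux.IFF_NO_PI)
    if err != nil {
        return err
    }
    if created {
        // A brand new NIC was registered on the stack; one-time setup
        // (e.g. address configuration) belongs here.
    }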
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b47a7be51..7df77c66e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -49,7 +49,6 @@ type endpoint struct {
enabled uint32
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
@@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 {
}
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
+ lmtu := e.nic.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
func (e *endpoint) Close() {
@@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ origSender := h.HardwareAddressSender()
+ r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
- packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.HardwareAddressTarget(), origSender)
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
@@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
+ NetProto: ProtocolNumber,
RemoteLinkAddress: remoteLinkAddr,
}
if len(r.RemoteLinkAddress) == 0 {
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index e247f06a4..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -44,5 +46,7 @@ go_test(
deps = [
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index e1909fab0..ed502a473 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,7 +13,7 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
@@ -25,12 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
- // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
- DefaultReassembleTimeout = 30 * time.Second
-
// HighFragThreshold is the threshold at which we start trimming old
// fragmented packets. Linux uses a default value of 4 MB. See
// net.ipv4.ipfrag_high_thresh for more information.
@@ -243,3 +241,78 @@ func (f *Fragmentation) releaseReassemblersLocked() {
f.release(r)
}
}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ innerMTU int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However, we do not currently support any such
+ // header, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ innerMTU: innerMTU,
+ fragmentCount: fragmentCount,
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied, and a boolean
+// indicating whether more fragments remain. Calling this function again after
+// the last fragment has been returned will panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
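The intended calling pattern is a simple drain loop over BuildNextFragment; note that fragmentCount above is the ceiling of the fragmentable size divided by innerMTU. A minimal sketch (assuming pkt, innerMTU, and reserve were already computed by the network layer):

    pf := fragmentation.MakePacketFragmenter(pkt, innerMTU, reserve)
    for {
        fragPkt, offset, copied, more := pf.BuildNextFragment()
        // The caller pushes a network header onto fragPkt here, using
        // offset and copied to fill in the fragmentation fields.
        _, _, _ = fragPkt, offset, copied
        if !more {
            break
        }
    }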
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 189b223c5..d3c7d7f92 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -20,10 +20,16 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
+// reassembleTimeout is a dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
@@ -96,7 +102,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
@@ -251,7 +257,7 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
@@ -275,7 +281,7 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
@@ -370,7 +376,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
_, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
@@ -381,3 +387,113 @@ func TestErrors(t *testing.T) {
})
}
}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ innerMTU int
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ innerMTU: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ innerMTU: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ innerMTU: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ innerMTU: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := fragPkt.Size(); got > test.innerMTU {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(originalPayload.ToView(), reassembledPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6861cfdaf..d436873b6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- tester testObject
+ testObject
mu struct {
sync.RWMutex
@@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
- return &t.tester
-}
-
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
@@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv4Addr
- nic.tester.dstAddr = remoteIPv4Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
@@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
@@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
@@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv6Addr
- nic.tester.dstAddr = remoteIPv6Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 0a7e98ed1..7fc12e229 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -28,12 +28,15 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 5c4f715d7..3407755ed 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,8 @@
package ipv4
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
// Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
+ headerChecksum := h.Checksum()
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
+ calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(headerChecksum)
+ if calculatedChecksum != headerChecksum {
+ // It's possible that a raw socket still expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
- // Make a copy of data before pkt gets sent to raw socket.
- // DeliverTransportPacket will take ownership of pkt.
- replyData := pkt.Data.Clone(nil)
- replyData.TrimFront(header.ICMPv4MinimumSize)
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that it is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+ replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
@@ -117,32 +117,49 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
defer r.Release()
- // Use the remote link address from the incoming packet.
- r.ResolveWith(remoteLinkAddr)
-
- // Prepare a reply packet.
- icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize)
- copy(icmpHdr, h)
- icmpHdr.SetType(header.ICMPv4EchoReply)
- icmpHdr.SetChecksum(0)
- icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0)))
- dataVV := buffer.View(icmpHdr).ToVectorisedView()
- dataVV.Append(replyData)
+ // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // the fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket().
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: dataVV,
+ Data: replyVV,
})
- // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
- // information we will have to change this code to handle the ICMP header
- // no longer being in the data buffer.
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // Send out the reply packet.
+
+ // The checksum will be calculated so we don't need to do it here.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
- Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, replyPkt); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- sent := r.Stats().ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
@@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Assume any type we don't know about may be an error type.
return nil
}
- } else if transportHeader.IsEmpty() {
- return nil
}
// Now work out how much of the triggering packet we should return.
@@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
requirement. We treat 8 bytes as the minimum but will try to send more.
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv4MinimumProcessableDatagramSize {
mtu = header.IPv4MinimumProcessableDatagramSize
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
@@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
ReserveHeaderBytes: headerLen,
Data: payload,
})
+
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter := sent.DstUnreachable
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
- if err := r.WritePacket(
+ if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
+ TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 746cf520d..c5ac7b8b5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -18,6 +18,7 @@ package ipv4
import (
"fmt"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -30,6 +31,15 @@ import (
)
const (
+ // As per RFC 791 section 3.2:
+ // The current recommendation for the initial timer setting is 15 seconds.
+ // This may be changed as experience with this protocol accumulates.
+ //
+ // Considering that it is an old recommendation, we use the same reassembly
+ // timeout that linux defines, which is 30 seconds:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
+ reassembleTimeout = 30 * time.Second
+
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -56,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
@@ -77,7 +86,6 @@ type endpoint struct {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
}
@@ -168,21 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -190,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is already present in pkt.NetworkHeader.
-// pkt.TransportHeader may be set. mtu includes the IP header and options. This
-// does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.NetworkHeader().View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
-
- // Keep the length reserved for link-layer, we need to create fragments with
- // the same reserved length.
- reservedForLink := pkt.AvailableHeaderBytes()
-
- // Destroy the packet, pull all payloads out for fragmentation.
- transHeader, data := pkt.TransportHeader().View(), pkt.Data
-
- // Where possible, the first fragment that is sent has the same
- // number of bytes reserved for header as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- transFitsFirst := len(transHeader) <= innerMTU
-
- for i := 0; i < n; i++ {
- reserve := reservedForLink + int(ip.HeaderLength())
- if i == 0 && transFitsFirst {
- // Reserve for transport header if it's going to be put in the first
- // fragment.
- reserve += len(transHeader)
- }
- fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: reserve,
- })
- fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-
- // Copy data for the fragment.
- avail := innerMTU
-
- if n := len(transHeader); n > 0 {
- if n > avail {
- n = avail
- }
- if i == 0 && transFitsFirst {
- copy(fragPkt.TransportHeader().Push(n), transHeader)
- } else {
- fragPkt.Data.AppendView(transHeader[:n:n])
- }
- transHeader = transHeader[n:]
- avail -= n
- }
-
- if avail > 0 {
- n := data.Size()
- if n > avail {
- n = avail
- }
- data.ReadToVV(&fragPkt.Data, n)
- avail -= n
- }
-
- copied := uint16(innerMTU - avail)
-
- // Set lengths in header and calculate checksum.
- h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
- copy(h, ip)
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + copied)
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += copied
+// writePacketFragments fragments pkt and writes the results on the link
+// endpoint. The IP header must already be present in the original packet. The
+// mtu is the maximum size of the generated packets.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
- // Send out the fragment.
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
r.Stats().IP.PacketsSent.Increment()
+ if !more {
+ break
+ }
}
+
return nil
}
@@ -303,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -329,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -345,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -377,8 +306,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -392,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -401,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -461,9 +394,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -566,7 +502,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported.
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -595,6 +537,13 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
}
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
e.mu.Lock()
@@ -625,11 +574,11 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
return addressEndpoint
}
-// AcquirePrimaryAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.AcquirePrimaryAddress(remoteAddr, allowExpired)
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
// PrimaryAddresses implements stack.AddressableEndpoint.
@@ -787,14 +736,36 @@ func calculateMTU(mtu uint32) uint32 {
return mtu - header.IPv4MinimumSize
}
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ return mtu
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number, and a 32-bit seed value
+// to generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
@@ -814,6 +785,29 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
}
}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
+ }
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
+}
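As a quick sanity check of the little-endian packing in addressToUint32 (a worked example, not code from this change):

    addr := tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1
    v := addressToUint32(addr)
    // v == 0x0100000a, identical to binary.LittleEndian.Uint32([]byte(addr))
    // but without allocating a byte slice.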
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 277560e35..9916d783f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,19 +16,24 @@ package ipv4_test
import (
"bytes"
+ "context"
"encoding/hex"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ options []byte
+ }{
+ {
+ name: "valid",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However, RFC 1122 section 3.2.1.7
+ // clarifies the case of the destination host as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPv4Options(test.options),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
+ }
+}
+
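The rounding expression in the test above is the usual round-up-to-a-multiple bit trick, needed because the IHL field counts 32-bit words. A quick standalone check with illustrative values:

package main

import "fmt"

// roundUpIHL pads a header length to the next multiple of 4 using the same
// (n + stride - 1) &^ (stride - 1) expression as the test, with stride
// standing in for header.IPv4IHLStride.
func roundUpIHL(n int) int {
	const stride = 4
	return (n + stride - 1) &^ (stride - 1)
}

func main() {
	for _, n := range []int{20, 21, 24, 27} {
		fmt.Printf("header+options of %d bytes is padded to %d\n", n, roundUpIHL(n))
	}
}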
// compareFragments compares the contents of all the packets against the
// contents of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if i == 0 {
- got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
- // sourcePacketInfo does not have NetworkHeader added, simulate one.
- want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
- // Check that it kept the transport header in packet.TransportHeader if
- // it fits in the first fragment.
- if want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- }
if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
@@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
func TestFragmentation(t *testing.T) {
+ const ttl = 42
+
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
@@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) {
payloadViewsSizes []int
expectedFrags int
}{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
}
for _, ft := range fragTests {
@@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) {
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("got err = %s, want = nil", err)
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
if got := len(ep.WrittenPackets); got != ft.expectedFrags {
@@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) {
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
@@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) {
// TestFragmentationErrors checks that errors are returned from WritePacket
// correctly.
func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ expectedError := tcpip.ErrAborted
fragTests := []struct {
description string
mtu uint32
transportHeaderLength int
- payloadViewsSizes []int
- err *tcpip.Error
+ payloadSize int
allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
+ {
+ description: "No frag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on first frag MTU smaller than header",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.err {
- t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
+ if err != expectedError {
+ t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ }
})
}
}
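The allowPackets knob works because the mock endpoint delivers a fixed number of packets and then fails every subsequent write, which pins the error to a chosen fragment. A sketch of that shape (the idea only, not the testutil implementation):

package main

import (
	"errors"
	"fmt"
)

var errAborted = errors.New("aborted")

// flakyLink accepts up to allow packets, then fails every write.
type flakyLink struct {
	allow   int
	written int
}

func (l *flakyLink) WritePacket(frag string) error {
	if l.written >= l.allow {
		return errAborted
	}
	l.written++
	return nil
}

func main() {
	l := &flakyLink{allow: 1} // fail on the second fragment
	for i, frag := range []string{"frag0", "frag1", "frag2"} {
		if err := l.WritePacket(frag); err != nil {
			fmt.Printf("fragment %d (%s): %v\n", i, frag, err)
			break
		}
		fmt.Printf("fragment %d (%s): sent\n", i, frag)
	}
}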
@@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
+
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
@@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\x10\x00\x00\x02"
)
if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
}
return rt
}
@@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for an ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
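The SetChecksum(0) / SetChecksum(^...) pattern that recurs throughout these tests is the RFC 1071 Internet checksum: sum the 16-bit words in ones' complement, fold the carries back in, and store the complement. A minimal standalone version (simplified; not the header package's implementation):

package main

import "fmt"

// internetChecksum computes the RFC 1071 ones'-complement sum of buf. The
// checksum field must hold zero while the sum is taken, which is why the
// tests call SetChecksum(0) first.
func internetChecksum(buf []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(buf); i += 2 {
		sum += uint32(buf[i])<<8 | uint32(buf[i+1])
	}
	if len(buf)%2 == 1 {
		sum += uint32(buf[len(buf)-1]) << 8 // pad the odd trailing byte
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16) // fold the carries back in
	}
	return uint16(sum)
}

func main() {
	// An ICMPv4 Echo Request header: type 8, code 0, checksum zeroed,
	// ident 0x1234, sequence 1.
	hdr := []byte{8, 0, 0, 0, 0x12, 0x34, 0, 1}
	fmt.Printf("checksum: 0x%04x\n", ^internetChecksum(hdr))
}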
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 97adbcbd4..a30437f02 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -41,6 +42,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4b4b483cc..a454f6c34 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -286,6 +286,17 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
+ // As per RFC 4861 section 7.1.1:
+ // A node MUST silently discard any received Neighbor Solicitation
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP source address is the unspecified address, the IP
+ // destination address is a solicited-node multicast address.
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ received.Invalid.Increment()
+ return
+ }
+
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
// route we end up with here has as its LocalAddress such a
@@ -322,18 +333,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
}
+ neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
})
- packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
+ na.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -350,7 +361,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.Invalid.Increment()
return
}
@@ -429,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := r.LocalAddress
@@ -445,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
defer r.Release()
- // Use the link address from the source of the original packet.
- r.ResolveWith(remoteLinkAddr)
-
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data,
@@ -614,18 +620,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
}
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
-
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-
- icmpV6FlagOffset = 4
- icmpV6OptOffset = 24
- icmpV6LengthOffset = 25
-)
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -635,31 +629,37 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- snaddr := header.SolicitedNodeAddr(addr)
-
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
// route here. Note, we would need the nicID to do this properly so the right
// NIC (associated to linkEP) is used to send the NDP NS message.
- r := &stack.Route{
+ r := stack.Route{
LocalAddress: localAddr,
- RemoteAddress: snaddr,
+ RemoteAddress: addr,
RemoteLinkAddress: remoteLinkAddr,
}
+
+ // If the remote link address is not already known, send the solicitation
+ // to the solicited-node multicast address, since multicast addresses have
+ // a static mapping to link addresses.
if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ r.RemoteAddress = header.SolicitedNodeAddr(addr)
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
}
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ }
+ neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
})
- icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- icmpHdr.SetType(header.ICMPv6NeighborSolicit)
- copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
- icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
- icmpHdr[icmpV6LengthOffset] = 1
- copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
+ packet.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns.SetTargetAddress(addr)
+ ns.Options().Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
@@ -672,7 +672,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
}
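The static mapping referred to above is the standard pair of rules: a solicited-node address is ff02::1:ff00:0/104 with the low 24 bits of the target appended, and an IPv6 multicast address maps to the Ethernet address 33:33 followed by its low 32 bits. A runnable sketch using only the standard library (the helper names are illustrative):

package main

import (
	"fmt"
	"net"
)

// solicitedNodeAddr builds the solicited-node multicast address for target.
func solicitedNodeAddr(target net.IP) net.IP {
	snm := net.ParseIP("ff02::1:ff00:0").To16()
	copy(snm[13:], target.To16()[13:]) // low 24 bits of the target
	return snm
}

// etherFromMulticastIPv6 maps an IPv6 multicast address to its Ethernet
// address: 33:33 plus the low 32 bits of the IPv6 address.
func etherFromMulticastIPv6(addr net.IP) net.HardwareAddr {
	ip := addr.To16()
	return net.HardwareAddr{0x33, 0x33, ip[12], ip[13], ip[14], ip[15]}
}

func main() {
	target := net.ParseIP("fe80::aabb:ccff:fedd:eeff")
	snm := solicitedNodeAddr(target)
	// Prints ff02::1:ffdd:eeff 33:33:ff:dd:ee:ff; no neighbor state is needed
	// to address the solicitation, which is why multicast NS works.
	fmt.Println(snm, etherFromMulticastIPv6(snm))
}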
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -690,6 +690,36 @@ type icmpReason interface {
isICMPReason()
}
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in RFC 4443 section 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
type icmpReasonPortUnreachable struct{}
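The unexported isICMPReason marker method is the usual Go way of sealing a set of variants, so the exhaustive type switch in returnError can safely panic on anything unexpected. A minimal sketch of the pattern with hypothetical names:

package main

import "fmt"

// reason can only be satisfied by types in this package that carry the
// unexported marker method.
type reason interface{ isReason() }

type portUnreachable struct{}

type paramProblem struct{ pointer uint32 }

func (portUnreachable) isReason() {}
func (paramProblem) isReason()    {}

func describe(r reason) string {
	switch r := r.(type) {
	case portUnreachable:
		return "destination unreachable (port)"
	case paramProblem:
		return fmt.Sprintf("parameter problem at octet %d", r.pointer)
	default:
		panic(fmt.Sprintf("unsupported reason %T", r))
	}
}

func main() {
	fmt.Println(describe(portUnreachable{}))
	fmt.Println(describe(paramProblem{pointer: 40}))
}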
@@ -698,18 +728,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- stats := r.Stats().ICMP
- sent := stats.V6PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
- // TODO(b/164522993) There are exceptions to this rule.
+ // There are exceptions to this rule.
// See: point e.3) RFC 4443 section-2.4
//
// (e) An ICMPv6 error message MUST NOT be originated as a result of
@@ -727,7 +750,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any {
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return nil
}
@@ -757,11 +805,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// packet that caused the error) as possible without making
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv6MinimumMTU {
mtu = header.IPv6MinimumMTU
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
available := int(mtu) - headerLen
if available < header.IPv6MinimumSize {
return nil
@@ -780,12 +828,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
- counter := sent.DstUnreachable
- err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
- if err != nil {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
sent.Dropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 31370c1d4..3affcc4e4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,21 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -39,6 +43,9 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
@@ -50,6 +57,10 @@ type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
// Indicate that resolution for link layer addresses is required to send
// packets over this link. This is needed so the NIC knows to allocate a
@@ -105,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -123,10 +136,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -1219,19 +1228,22 @@ func TestLinkAddressRequest(t *testing.T) {
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectedLinkAddr tcpip.LinkAddress
+ expectedAddr tcpip.Address
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectLinkAddr: linkAddr1,
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectedLinkAddr: linkAddr1,
+ expectedAddr: lladdr0,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectLinkAddr: mcaddr,
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectedLinkAddr: mcaddr,
+ expectedAddr: snaddr,
},
}
@@ -1254,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
+ if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ }
+ if pkt.Route.RemoteAddress != test.expectedAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ }
+ if pkt.Route.LocalAddress != lladdr1 {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ }
+}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
}
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
}
}
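Both packet-queuing tests wait with context.Background(); where a bounded wait is preferred (the defaultAsyncPositiveEventTimeout constant added earlier in this file), the usual shape is a context with a deadline. A generic sketch with the endpoint reduced to a channel:

package main

import (
	"context"
	"fmt"
	"time"
)

// readContext waits for a packet or for the context to expire, the shape a
// deadline-bounded version of the e.ReadContext calls above would take.
func readContext(ctx context.Context, pkts <-chan string) (string, bool) {
	select {
	case p := <-pkts:
		return p, true
	case <-ctx.Done():
		return "", false
	}
}

func main() {
	pkts := make(chan string, 1)
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	if _, ok := readContext(ctx, pkts); !ok {
		fmt.Println("timed out waiting for packet")
	}
}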
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 3705f56a2..2bd8f4ece 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,9 +16,12 @@
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
"sort"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,10 +29,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // As per RFC 8200 section 4.5:
+ // If insufficient fragments are received to complete reassembly of a packet
+ // within 60 seconds of the reception of the first-arriving fragment of that
+ // packet, reassembly of that packet must be abandoned.
+ //
+ // Linux also uses 60 seconds for reassembly timeout:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
+ reassembleTimeout = 60 * time.Second
+
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -40,6 +53,9 @@ const (
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets is the number of buckets for fragment identifiers.
+ buckets = 2048
)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
@@ -50,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
@@ -348,21 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
@@ -376,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+ return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled, the number of
+// fragments left to be processed, and any error encountered. The IP header
+// must already be present in the original packet. The mtu is the maximum
+// size of the fragments to generate. The transport header protocol number
+// is required to avoid parsing the IPv6 extension headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ if fragMTU < pkt.TransportHeader().View().Size() {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ break
+ }
+ }
+
+ return n, 0, nil
}
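The identifier allocation above keeps one atomic counter per flow bucket, so unrelated flows neither contend on nor observe each other's ID sequences. A standalone sketch of the same scheme (only the bucket count is taken from this patch):

package main

import (
	"fmt"
	"sync/atomic"
)

const buckets = 2048

// ids holds one independent fragment-ID counter per flow bucket.
var ids [buckets]uint32

// nextFragmentID returns a fresh identifier for the flow that hashed to
// flowHash, using the same atomic increment handleFragments performs.
func nextFragmentID(flowHash uint32) uint32 {
	return atomic.AddUint32(&ids[flowHash%buckets], 1)
}

func main() {
	fmt.Println(nextFragmentID(42), nextFragmentID(42), nextFragmentID(7)) // 1 2 1
}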
// WritePacket writes a packet to the given destination address and protocol.
@@ -402,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -423,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
+
r.Stats().IP.PacketsSent.Increment()
return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -441,6 +500,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
+ if e.packetMustBeFragmented(pb, gso) {
+ current := pb
+ _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(current, fragPkt)
+ current = current.Next()
+ return nil
+ })
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // The packet that was just fragmented is no longer needed; continue
+ // processing the remaining packets in the list.
+ pkts.Remove(pb)
+ pb = current
+ }
}
// iptables filtering. All packets that reach here are locally
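The in-place splice in the WritePackets hunk above inserts each new fragment after the current element and finally removes the oversized original, so iteration resumes cleanly after the last fragment. The same dance on container/list, with strings standing in for packet buffers:

package main

import (
	"container/list"
	"fmt"
)

// spliceFragments inserts frags after e and removes e itself, returning the
// element iteration should continue from.
func spliceFragments(l *list.List, e *list.Element, frags []string) *list.Element {
	cur := e
	for _, f := range frags {
		cur = l.InsertAfter(f, cur)
	}
	l.Remove(e) // the original oversized packet leaves the list
	return cur
}

func main() {
	l := list.New()
	a := l.PushBack("pkt-a")
	l.PushBack("pkt-b")
	spliceFragments(l, a, []string{"a-frag0", "a-frag1"})
	for e := l.Front(); e != nil; e = e.Next() {
		fmt.Println(e.Value) // a-frag0, a-frag1, pkt-b
	}
}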
@@ -451,8 +527,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -466,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -475,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -536,7 +616,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
- for firstHeader := true; ; firstHeader = false {
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop-by-Hop header at a location other than the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -550,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -576,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -593,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -737,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -767,8 +878,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ r.Stats().IP.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -777,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -922,6 +1055,13 @@ func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpo
return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
}
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
e.mu.Lock()
@@ -937,18 +1077,18 @@ func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, all
return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
}
-// AcquirePrimaryAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.acquirePrimaryAddressRLocked(remoteAddr, allowExpired)
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
-// acquirePrimaryAddressRLocked is like AcquirePrimaryAddress but with locking
-// requirements.
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
//
// Precondition: e.mu must be read locked.
-func (e *endpoint) acquirePrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
// addrCandidate is a candidate for Source Address Selection, as per
// RFC 6724 section 5.
type addrCandidate struct {
@@ -957,7 +1097,7 @@ func (e *endpoint) acquirePrimaryAddressRLocked(remoteAddr tcpip.Address, allowE
}
if len(remoteAddr) == 0 {
- return e.mu.addressableEndpointState.AcquirePrimaryAddress(remoteAddr, allowExpired)
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
// Create a candidate set of available addresses we can potentially use as a
@@ -1090,6 +1230,9 @@ type protocol struct {
eps map[*endpoint]struct{}
}
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -1150,7 +1293,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
@@ -1311,10 +1453,15 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
ndpDisp: opts.NDPDisp,
ndpConfigs: opts.NDPConfigs,
@@ -1332,3 +1479,73 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
+
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link-layer MTU and pkt's
+// network header size.
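+//
+// For example, with a link MTU of 1280 and a bare 40-byte IPv6 header, a
+// fragment can carry 1280 - 40 - 8 = 1232 bytes of fragmentable data, which
+// is already a multiple of 8, so the alignment mask leaves it unchanged.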
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // MTU because they should only be transmitted once.
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ mtu -= header.IPv6FragmentHeaderSize
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source
+// and destination addresses and a 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // FNV-1a was chosen because it is a fast hashing algorithm and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
+}
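
To make the new ids/hashIV plumbing concrete, here is a minimal, self-contained sketch of drawing a per-route fragment Identification value from a table of counters keyed by the FNV-1a route hash. Only hashRoute mirrors the code above; the bucket table, its size, and the atomic increment are illustrative assumptions about how the protocol consumes ids and hashIV.

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
	"sync/atomic"
)

const buckets = 2048 // illustrative table size

// hashRoute mirrors the function above: FNV-1a over the source address,
// the destination address, and the random 32-bit IV.
func hashRoute(local, remote []byte, hashIV uint32) uint32 {
	h := fnv.New32a()
	h.Write(local) // fnv's Write never returns an error.
	h.Write(remote)
	var iv [4]byte
	binary.LittleEndian.PutUint32(iv[:], hashIV)
	h.Write(iv[:])
	return h.Sum32()
}

func main() {
	ids := make([]uint32, buckets) // seeded with hash.RandN32 in the stack
	const hashIV = 0x1f2e3d4c      // likewise random, fixed at protocol creation
	local := []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
	remote := []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")

	bucket := hashRoute(local, remote, hashIV) % buckets
	id := atomic.AddUint32(&ids[bucket], 1) // a fresh Identification value
	fmt.Printf("bucket=%d id=%d\n", bucket, id)
}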
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 94344057e..bee18d1a8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,17 +15,21 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
"math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -53,8 +57,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
t.Helper()
// Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
@@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be the same across fragments, except for
+ // the payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(sourceIPHeader, fragmentIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ result := reassembledPayload.ToView()
+ if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), result); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
@@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
protocolFactory stack.TransportProtocolFactory
@@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -336,9 +419,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
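+ // The unknown option (type 191) begins at byte 8 of the extension
+ // header, so the reported pointer is 40 (fixed header) + 8 = 48.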
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -348,12 +452,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -364,39 +474,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The Hop By Hop extension header may only appear immediately
+ // after the fixed header, so placing it second is an error.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The Hop By Hop extension header may only appear immediately
+ // after the fixed header, so placing it second is an error.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
@@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View because that is what
+ // the checkers expect.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // The ICMP error packets generated here are known to carry no extension
+ // headers.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("got ICMPv6 pointer = %d, want = %d", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -683,7 +972,6 @@ type fragmentData struct {
func TestReceiveIPv6Fragments(t *testing.T) {
const (
- nicID = 1
udpPayload1Length = 256
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
@@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
-
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
}
return rt
}
@@ -1922,3 +2210,293 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
}
}
}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+type fragmentationTestCase struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ extraHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ expectedFrags int
+}
+
+var fragmentationTests = []fragmentationTestCase{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso nil",
+ mtu: 1280,
+ gso: nil,
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header and prependable bytes",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 20,
+ extraHdrLen: header.IPv6MinimumSize + 66,
+ payloadSize: 1500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 296, more: false},
+ },
+ },
+}
+
+func TestFragmentation(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: transportProto,
+ TTL: ttl,
+ TOS: tos,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if len(ep.WrittenPackets) > 0 {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on packet with MTU smaller than transport header",
+ mtu: 1280,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
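
The wantFragments values in the tests above follow directly from the inner-MTU arithmetic in calculateFragmentInnerMTU. A tiny standalone check of the "Fragmented" case, with the constants restated for illustration:

package main

import "fmt"

func main() {
	const (
		mtu     = 1280 // link MTU used by the tests
		ipv6Hdr = 40   // fixed IPv6 header, no extension headers
		fragHdr = 8    // fragment extension header
		payload = 2000
	)
	inner := (mtu - ipv6Hdr - fragHdr) &^ 7 // 1232 data bytes per fragment
	first := inner + fragHdr                // first Payload Length field: 1240
	second := (payload - inner) + fragHdr   // second Payload Length field: 776
	offset := inner / 8                     // second fragment offset: 154 units
	fmt.Println(first, second, offset)      // 1240 776 154
}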
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 84c082852..40da011f8 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.ep.linkEP.LinkAddress()
+ linkAddr := ndp.ep.nic.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
@@ -1891,7 +1891,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquirePrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
if addressEndpoint == nil {
// In case this ends up creating a new temporary address, we need to hold
// onto the endpoint until a route is obtained. If we decrement the
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 25464a03a..9033a9ed5 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naDst tcpip.Address
}{
{
- name: "Unspecified source to multicast destination",
+ name: "Unspecified source to solicited-node multicast destination",
nsOpts: nil,
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
@@ -437,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
+ nsInvalid: true,
},
{
name: "Unspecified source with source ll option to unicast destination",
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index c9e57dc0d..d0ffc299a 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -8,6 +8,7 @@ go_library(
"testutil.go",
],
visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 2eaeab779..eba97334e 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,7 +56,6 @@ go_library(
srcs = [
"addressable_endpoint_state.go",
"conntrack.go",
- "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -73,6 +72,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -123,7 +123,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -139,7 +138,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarder_test.go",
+ "forwarding_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 270ac4977..4d3acab96 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -412,6 +412,60 @@ func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState)
a.releaseAddressStateLocked(addrState)
}
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: a.mu must be read locked.
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
// AcquireAssignedAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
@@ -461,47 +515,34 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
-// AcquirePrimaryAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
defer a.mu.RUnlock()
- var deprecatedEndpoint *addressState
- for _, ep := range a.mu.primary {
- if !ep.IsAssigned(allowExpired) {
- continue
- }
-
- if !ep.Deprecated() {
- if ep.IncRef() {
- // ep is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- a.decAddressRefLocked(deprecatedEndpoint)
- deprecatedEndpoint = nil
- }
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
- return ep
- }
- } else if deprecatedEndpoint == nil && ep.IncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of
- // ep in case a doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, ep's reference count
- // will be decremented.
- deprecatedEndpoint = ep
- }
- }
-
- // a doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if a doesn't have any valid deprecated
- // endpoints either).
- if deprecatedEndpoint == nil {
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset (T=nil, V is
+ // not set). In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
return nil
}
- return deprecatedEndpoint
+
+ return ep
}
// PrimaryAddresses implements AddressableEndpoint.
@@ -638,11 +679,6 @@ type addressState struct {
}
}
-// NetworkEndpoint implements AddressEndpoint.
-func (a *addressState) NetworkEndpoint() NetworkEndpoint {
- return a.addressableEndpointState.networkEndpoint
-}
-
// AddressWithPrefix implements AddressEndpoint.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
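
The typed-nil subtlety documented in AcquireOutgoingPrimaryAddress above is easy to reproduce in isolation. A minimal illustration with stand-in types (AddressEndpoint and addressState here are stripped-down stand-ins, not the stack's definitions):

package main

import "fmt"

type AddressEndpoint interface{ IncRef() bool }

type addressState struct{}

func (*addressState) IncRef() bool { return true }

func buggy() AddressEndpoint {
	var ep *addressState // typed nil
	return ep            // interface value is (T=*addressState, V=nil)
}

func fixed() AddressEndpoint {
	var ep *addressState
	if ep == nil {
		return nil // untyped nil: interface value is (T=nil, V=nil)
	}
	return ep
}

func main() {
	fmt.Println(buggy() == nil) // false: the interface still carries a type
	fmt.Println(fixed() == nil) // true
}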
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index de4e0d7b1..26787d0a3 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -24,8 +24,13 @@ import (
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
var s stack.AddressableEndpointState
- s.Init(&fakeNetworkEndpoint{})
+ s.Init(&ep)
addr := tcpip.AddressWithPrefix{
Address: "\x01",
@@ -43,7 +48,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
{
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep == nil {
- t.Fatalf("got s.AcquireAssignedAddress(%s) = nil, want = non-nil", addr.Address)
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
}
ep.DecRef()
}
@@ -63,7 +68,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep != nil {
ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
if s.IsInGroup(group) {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 836682ea0..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Precondition: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4e4b00a92..cf042309e 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -48,10 +48,9 @@ const (
type fwdTestNetworkEndpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
+ nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
- ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
@@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool {
func (*fwdTestNetworkEndpoint) Disable() {}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index faa503b00..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, tcpip.ErrNotConnected
}
- return it.connections.originalDst(epID)
+ return it.connections.originalDst(epID, netProto)
}
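
Callers of OriginalDst must now pass the connection's network protocol so that dual-stack lookups hit the right tuple. A hedged usage sketch (variable names assumed):

    // it is the stack's *IPTables; epID identifies the redirected connection.
    addr, port, err := it.OriginalDst(epID, header.IPv6ProtocolNumber)
    if err != nil {
        // Not a tracked, redirected connection (e.g. tcpip.ErrNotConnected).
    }
    _, _ = addr, port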
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 8581dd5e8..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -34,7 +34,7 @@ func (at *AcceptTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -52,7 +52,7 @@ func (dt *DropTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -76,7 +76,7 @@ func (et *ErrorTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -99,7 +99,7 @@ func (uc *UserChainTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -118,7 +118,7 @@ func (rt *ReturnTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -128,26 +128,14 @@ func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.
const RedirectTargetName = "REDIRECT"
// RedirectTarget redirects the packet by modifying the destination port/IP.
-// Min and Max values for IP and Ports in the struct indicate the range of
-// values which can be used to redirect.
+// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
+// them.
type RedirectTarget struct {
- // TODO(gvisor.dev/issue/170): Other flags need to be added after
- // we support them.
- // RangeProtoSpecified flag indicates single port is specified to
- // redirect.
- RangeProtoSpecified bool
+ // Addr indicates the address used to redirect.
+ Addr tcpip.Address
- // MinIP indicates address used to redirect.
- MinIP tcpip.Address
-
- // MaxIP indicates address used to redirect.
- MaxIP tcpip.Address
-
- // MinPort indicates port used to redirect.
- MinPort uint16
-
- // MaxPort indicates port used to redirect.
- MaxPort uint16
+ // Port indicates the port used to redirect.
+ Port uint16
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
@@ -165,7 +153,7 @@ func (rt *RedirectTarget) ID() TargetID {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -176,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
- rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
- rt.MinIP = address
- rt.MaxIP = address
+ rt.Addr = address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.MinPort)
+ udpHeader.SetDestinationPort(rt.Port)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -212,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.MinIP)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
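
With the Min/Max range fields collapsed, building a REDIRECT target reduces to a single port (Addr is filled in by Action according to the hook, so callers normally set only Port and NetworkProtocol). A sketch, under those assumptions:

    target := &stack.RedirectTarget{
        Port:            8080,
        NetworkProtocol: ipv4.ProtocolNumber,
    }
    rule := stack.Rule{Target: target} // installed into a NAT table as usual
    _ = rule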
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 27e1feec0..4df288798 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
- case Reachable, Static:
+ case Stale:
+ entry.handlePacketQueuedLocked()
+ fallthrough
+ case Reachable, Static, Delay, Probe:
+ // As per RFC 4861 section 7.3.3:
+ // "Neighbor Unreachability Detection operates in parallel with the sending
+ // of packets to a neighbor. While reasserting a neighbor's reachability,
+ // a node continues sending packets to that neighbor using the cached
+ // link-layer address."
return entry.neigh, nil, nil
-
- case Unknown, Incomplete, Stale, Delay, Probe:
+ case Unknown, Incomplete:
entry.addWakerLocked(w)
if entry.done == nil {
@@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.handlePacketQueuedLocked()
return entry.neigh, entry.done, tcpip.ErrWouldBlock
-
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
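
From a caller's point of view the new switch means a cached-but-stale entry is returned immediately while reachability is re-confirmed in the background; only Unknown and Incomplete entries block. A sketch of caller-side handling (variable names assumed):

    e, doneCh, err := cache.entry(remoteAddr, localAddr, linkRes, waker)
    switch err {
    case nil:
        _ = e.LinkAddr // usable now, even if confirmation is still running
    case tcpip.ErrWouldBlock:
        <-doneCh // resolution in progress; wait for it to finish
    case tcpip.ErrNoLinkAddress:
        // Resolution already failed for this neighbor.
    }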
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index a0b7da5cd..fcd54ed83 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) {
}
// Verify the entry exists
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
- }
- if t.Failed() {
- t.FailNow()
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ {
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
}
// Notify of a link address change
@@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) {
IsRouter: false,
})
- // Requesting the entry again should start address resolution
+ // Requesting the entry again should start neighbor reachability confirmation.
+ //
+ // Verify the entry's new link address and the new state.
{
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
- select {
- case <-doneCh:
- default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
- // Verify the entry's new link address
+ // Verify that the neighbor is now reachable.
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- want = NeighborEntry{
+ want := NeighborEntry{
Addr: entry.Addr,
LocalAddr: entry.LocalAddr,
LinkAddr: updatedLinkAddr,
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 9a72bec79..4d69a4de1 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index a265fff0a..e79abebca 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- id: entryTestNICID,
- linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+
+ id: entryTestNICID,
stack: &Stack{
clock: clock,
nudDisp: &disp,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 212c6edae..8828cc5fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -32,14 +32,18 @@ var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
+ LinkEndpoint
+
stack *Stack
id tcpip.NICID
name string
- linkEP LinkEndpoint
context NICContext
- stats NICStats
- neigh *neighborCache
+ stats NICStats
+ neigh *neighborCache
+
+ // The network endpoints themselves may be modified by calling the interface's
+ // methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -88,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
+ LinkEndpoint: ep,
+
stack: stack,
id: id,
name: name,
- linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
@@ -127,11 +132,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.linkEP.Attach(nic)
+ nic.LinkEndpoint.Attach(nic)
return nic
}
+func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+ return n.networkEndpoints[proto]
+}
+
// Enabled implements NetworkInterface.
func (n *NIC) Enabled() bool {
return atomic.LoadUint32(&n.enabled) == 1
@@ -211,10 +220,9 @@ func (n *NIC) remove() *tcpip.Error {
for _, ep := range n.networkEndpoints {
ep.Close()
}
- n.networkEndpoints = nil
// Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.LinkEndpoint.Attach(nil)
return nil
}
@@ -234,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool {
// IsLoopback implements NetworkInterface.
func (n *NIC) IsLoopback() bool {
- return n.linkEP.Capabilities()&CapabilityLoopback != 0
+ return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
+}
+
+// WritePacket implements NetworkLinkEndpoint.
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 5.2 (for IPv6):
+ // Once the IP address of the next-hop node is known, the sender
+ // examines the Neighbor Cache for link-layer information about that
+ // neighbor. If no entry exists, the sender creates one, sets its state
+ // to INCOMPLETE, initiates Address Resolution, and then queues the data
+ // packet pending completion of address resolution.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ r := r.Clone()
+ n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ return nil
+ }
+ return err
+ }
+
+ return n.writePacket(r, gso, protocol, pkt)
+}
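
The RFC-mandated behavior above is a three-way decision; restated from the caller's side as a sketch (assumed names):

    ch, err := r.Resolve(nil)
    switch err {
    case nil:
        // Next-hop link address already known: transmit immediately.
    case tcpip.ErrWouldBlock:
        // Park the packet on the stack-wide queue; it is written once ch fires.
    default:
        // Resolution cannot succeed; the error propagates to the caller.
    }
    _ = ch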
+
+func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // WritePacket takes ownership of pkt, so calculate numBytes first.
+ numBytes := pkt.Size()
+
+ if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ return err
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
+
+// WritePackets implements NetworkLinkEndpoint.
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/4458): Queue packets while link address resolution
+ // is being performed, as WritePacket does.
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
+ }
+
+ n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return writtenPackets, err
}
// setSpoofing enables or disables address spoofing.
@@ -244,22 +309,19 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
+// primaryEndpoint returns a primary address endpoint that may be used to
+// communicate with remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquirePrimaryAddress(remoteAddr, n.mu.spoofing)
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -360,13 +422,12 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- addressEndpoint := n.primaryEndpoint(proto, "")
- if addressEndpoint == nil {
+ ep, ok := n.networkEndpoints[proto]
+ if !ok {
return tcpip.AddressWithPrefix{}
}
- addr := addressEndpoint.AddressWithPrefix()
- addressEndpoint.DecRef()
- return addr
+
+ return ep.MainAddress()
}
// removeAddress removes an address from n.
@@ -487,9 +548,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
- addressEndpoint.DecRef()
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -523,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.linkEP.LinkAddress()
+ local = n.LinkEndpoint.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -603,11 +664,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n := r.nic
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -618,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
- // forwarder will release route.
- return
- }
+
+ // pkt may already have its header set and may not have enough headroom
+ // for the link-layer header the other link needs to prepend. Create a new
+ // packet to forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- r.Release()
- return
}
- // The link-address resolution finished immediately.
- n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -656,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.linkEP.AddHeader(local, remote, protocol, p)
+ n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- // pkt may have set its header and may not have enough headroom for link-layer
- // header for the other link to prepend. Here we create a new packet to
- // forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // WritePacket takes ownership of fwdPkt, calculate numBytes first.
- numBytes := fwdPkt.Size()
-
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- // TODO(gvisor.dev/issue/4365): Let the caller know that the transport
- // protocol is unrecognized.
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return TransportPacketHandled
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -796,11 +832,6 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint implements NetworkInterface.
-func (n *NIC) LinkEndpoint() LinkEndpoint {
- return n.linkEP
-}
-
// nudConfigs gets the NUD configurations for n.
func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
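
One behavioral change earlier in this file is easy to miss: DeliverTransportPacket now reports an unrecognized transport protocol as TransportPacketProtocolUnreachable (defined in registration.go below) instead of claiming the packet was handled, letting network endpoints react, e.g. with an ICMP protocol-unreachable error. A dispatch sketch (assumed caller):

    switch nic.DeliverTransportPacket(route, protoNum, pkt) {
    case stack.TransportPacketHandled:
        // Nothing more to do.
    case stack.TransportPacketProtocolUnreachable:
        // The network endpoint may emit an ICMP protocol-unreachable error.
    }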
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index fdd49b77f..97a96af62 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
type testIPv6Endpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
- linkEP LinkEndpoint
+ nic NetworkInterface
protocol *testIPv6Protocol
invalidatedRtr tcpip.Address
@@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.linkEP.MTU() - header.IPv6MinimumSize
+ return e.nic.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint implements NetworkProtocol.NewEndpoint.
func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
- nicID: nic.ID(),
- linkEP: nic.LinkEndpoint(),
+ nic: nic,
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index a7d9d59fa..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
@@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
return newPk
}
+// Network returns the network header as a header.Network.
+//
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
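+
+// Network() is what lets the conntrack and REDIRECT code above manipulate
+// either IP version through one interface. A usage sketch (newDst is an
+// assumed variable):
+//
+//	netHdr := pkt.Network()              // header.IPv4 or header.IPv6
+//	proto := netHdr.TransportProtocol()  // works for both versions
+//	netHdr.SetDestinationAddress(newDst) // ditto for address rewriting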
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go
index 3eff141e6..f838eda8d 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
- nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-type forwardQueue struct {
+// packetsPendingLinkResolution is a queue of packets pending link resolution.
+//
+// Once link resolution completes successfully, the packets will be written.
+type packetsPendingLinkResolution struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]*pendingPacket
+ packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func newForwardQueue() *forwardQueue {
- return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+func (f *packetsPendingLinkResolution) init() {
+ f.Lock()
+ defer f.Unlock()
+ f.packets = make(map[<-chan struct{}][]pendingPacket)
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- shouldWait := false
-
+func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
f.Lock()
+ defer f.Unlock()
+
packets, ok := f.packets[ch]
- if !ok {
- shouldWait = true
- }
- for len(packets) == maxPendingPacketsPerResolution {
+ if len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
+ packets[0] = pendingPacket{}
packets = packets[1:]
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
+
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
- f.packets[ch] = append(packets, &pendingPacket{
- nic: n,
+
+ f.packets[ch] = append(packets, pendingPacket{
route: r,
- proto: protocol,
+ proto: proto,
pkt: pkt,
})
- f.Unlock()
- if !shouldWait {
+ if ok {
return
}
// Wait for the link-address resolution to complete.
- // Start a goroutine with a forwarding-cancel channel so that we can
- // limit the maximum number of goroutines running concurrently.
- cancel := f.newCancelChannel()
+ cancel := f.newCancelChannelLocked()
go func() {
cancelled := false
select {
@@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
}
f.Lock()
- packets := f.packets[ch]
+ packets, ok := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
+ if !ok {
+ panic("link-resolution goroutine woke up but no entry exists in the queue of packets")
+ }
+
for _, p := range packets {
if cancelled {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
@@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
-// newCancelChannel creates a channel that can cancel a pending forwarding
-// activity. The oldest channel is closed if the number of open channels would
-// exceed maxPendingResolutions.
+// newCancelChannelLocked creates a channel that can cancel a pending wait for
+// link resolution. The oldest channel is closed if the number of open channels
+// would exceed maxPendingResolutions.
-func (f *forwardQueue) newCancelChannel() chan struct{} {
- f.Lock()
- defer f.Unlock()
-
+func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
+ f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
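
Note the packets[0] = pendingPacket{} and cancelChans[0] = nil writes before the slices shrink: re-slicing alone would keep the dropped element reachable through the backing array, pinning its route and packet. The same idiom standalone:

    // dropOldest removes q's head without keeping its references alive.
    func dropOldest(q []pendingPacket) []pendingPacket {
        q[0] = pendingPacket{} // zero the slot so the GC can reclaim it
        return q[1:]
    }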
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 567e1904e..defb9129b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -208,6 +208,10 @@ const (
// transport layer and callers need not take any further action.
TransportPacketHandled TransportPacketDisposition = iota
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
// TransportPacketDestinationPortUnreachable indicates that there weren't any
// listeners interested in the packet and the transport protocol has no means
// to notify the sender.
@@ -322,10 +326,6 @@ const (
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
- // NetworkEndpoint returns the NetworkEndpoint the receiver is associated
- // with.
- NetworkEndpoint() NetworkEndpoint
-
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
@@ -435,7 +435,10 @@ type AddressableEndpoint interface {
// permanent address.
RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
- // AcquireAssignedAddress returns an AddressEndpoint for the passed address
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
// that is considered bound to the endpoint, optionally creating a temporary
// endpoint if requested and no existing address exists.
//
@@ -444,15 +447,15 @@ type AddressableEndpoint interface {
// Returns nil if the specified address is not local to this endpoint.
AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
- // AcquirePrimaryAddress returns a primary endpoint to use when communicating
- // with the passed remote address.
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
//
// If allowExpired is true, expired addresses may be returned.
//
// The returned endpoint's reference count is incremented.
//
- // Returns nil if a primary endpoint is not available.
- AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
// PrimaryAddresses returns the primary addresses.
PrimaryAddresses() []tcpip.AddressWithPrefix
@@ -472,6 +475,8 @@ type NDPEndpoint interface {
// NetworkInterface is a network interface.
type NetworkInterface interface {
+ NetworkLinkEndpoint
+
// ID returns the interface's ID.
ID() tcpip.NICID
@@ -485,9 +490,6 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
-
- // LinkEndpoint returns the link endpoint backing the interface.
- LinkEndpoint() LinkEndpoint
}
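
Because NetworkInterface now embeds NetworkLinkEndpoint, NIC satisfies most of the enlarged surface by embedding its LinkEndpoint, overriding only the write methods to add queuing and stats (see nic.go above). A test double can use the same trick; this partial sketch assumes the remaining methods (Name, Enabled, IsLoopback) are implemented similarly:

    type testInterface struct {
        stack.LinkEndpoint // provides MTU, MaxHeaderLength, WritePacket, ...
        id tcpip.NICID
    }

    func (t *testInterface) ID() tcpip.NICID { return t.id }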
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -660,22 +662,15 @@ const (
CapabilitySoftwareGSO
)
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
+// NetworkLinkEndpoint is a data-link-layer endpoint that supports sending
+// network-layer packets.
+type NetworkLinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -683,7 +678,7 @@ type LinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // link endpoint.
+ // endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -703,6 +698,19 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+}
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ NetworkLinkEndpoint
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 5ade3c832..25f80c1f8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: linkEP.LinkAddress(),
+ LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = nic.stack
+ r.linkCache = r.nic.stack
}
}
@@ -100,7 +99,7 @@ func (r *Route) NICID() tcpip.NICID {
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength()
+ return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
@@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint.Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok {
+ if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
-// ResolveWith immediately resolves a route with the specified remote link
-// address.
-func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
-}
-
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -208,17 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
return tcpip.ErrInvalidEndpointState
}
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- }
- return err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
@@ -228,22 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return n, err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -253,32 +221,17 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Data.Size()
-
- if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.addressEndpoint.NetworkEndpoint().DefaultTTL()
+ return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.addressEndpoint.NetworkEndpoint().MTU()
-}
-
-// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
-// network endpoint.
-func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber()
+ return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 57d8e79e0..3a07577c8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -436,9 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // forwarder holds the packets that wait for their link-address resolutions
- // to complete, and forwards them when each resolution is done.
- forwarder *forwardQueue
+ // linkResQueue holds packets that are waiting for link resolution to
+ // complete.
+ linkResQueue packetsPendingLinkResolution
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -550,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Precondition: the parent endpoint mu must be held while calling this method.
-func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(e.ID.LocalAddress) {
+ switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == e.NetProto:
- case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ case netProto == t.NetProto:
+ case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -640,7 +640,6 @@ func New(opts Options) *Stack {
useNeighborCache: opts.UseNeighborCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
@@ -653,6 +652,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetNICByName gets the NIC specified by name.
-func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+// GetLinkEndpointByName gets the link endpoint specified by name.
+func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic, true
+ return nic.LinkEndpoint
}
}
- return nil, false
+ return nil
}
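
The rename also changes the return shape: absence is now signaled by nil rather than a second boolean. Usage sketch:

    if ep := s.GetLinkEndpointByName("eth0"); ep != nil {
        _ = ep.MTU() // operate on the link endpoint directly
    }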
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.linkEP.LinkAddress(),
+ LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.linkEP.MTU(),
+ MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.linkEP.ARPHardwareType(),
+ ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
}
return nics
@@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1539,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.linkEP.Wait()
+ n.LinkEndpoint.Wait()
}
}
@@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.linkEP.LinkAddress(),
+ SrcAddr: nic.LinkEndpoint.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
return err
}
@@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
return err
}
@@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco
return nil, tcpip.ErrUnknownNICID
}
- return nic.networkEndpoints[proto], nil
+ return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
if addressEndpoint == nil {
continue
}
-
- ep := addressEndpoint.NetworkEndpoint()
addressEndpoint.DecRef()
- return ep, nil
+ return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index fda22c550..38994cca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"math"
- "net"
"sort"
"testing"
"time"
@@ -35,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -71,24 +69,38 @@ const (
type fakeNetworkEndpoint struct {
stack.AddressableEndpointState
- nicID tcpip.NICID
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
+ nic stack.NetworkInterface
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
}
-func (*fakeNetworkEndpoint) Enable() *tcpip.Error {
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
return nil
}
-func (*fakeNetworkEndpoint) Enabled() bool {
- return true
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
}
-func (*fakeNetworkEndpoint) Disable() {}
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
+}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -120,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+ return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -149,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -201,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
@@ -2091,7 +2102,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -3487,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
-func TestResolveWith(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
- }
- defer r.Release()
-
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
- }
-
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
- }
-}
-
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
@@ -3620,3 +3585,43 @@ func TestGetNetworkEndpoint(t *testing.T) {
})
}
}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 06c7a3cd3..a4f141253 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -6,6 +6,8 @@ go_test(
name = "integration_test",
size = "small",
srcs = [
+ "forward_test.go",
+ "link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
],
@@ -15,6 +17,8 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
new file mode 100644
index 000000000..ffd38ee1a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -0,0 +1,378 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestForwarding(t *testing.T) {
+ const (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+
+ listenPort = 8080
+ )
+
+ host1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 host2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack)
+ defer epsAndAddrs.serverEP.Close()
+ defer epsAndAddrs.clientEP.Close()
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) {
+ t.Helper()
+
+ dataPayload := tcpip.SlicePayload(data)
+ wOpts := tcpip.WriteOptions{To: to}
+ n, ch, err := ep.Write(dataPayload, wOpts)
+ if err == tcpip.ErrNoLinkAddress {
+ // Wait for link resolution to complete.
+ <-ch
+
+ n, _, err = ep.Write(dataPayload, wOpts)
+ } else if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+
+ if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+ }
+
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data, &serverAddr)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress {
+ t.Helper()
+
+ // Wait for the endpoint to be readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+
+ if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != expectedFrom {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ return addr
+ }
+
+ addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
+ // Unspecify the NIC since NIC IDs are meaningless across stacks.
+ addr.NIC = 0
+
+ data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ write(epsAndAddrs.serverEP, data, &addr)
+ addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr)
+ if addr.Port != listenPort {
+ t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort)
+ }
+ })
+ }
+}
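
Editor's note: the write helper in TestForwarding above shows this tree's pre-resolution write pattern: a first Write may fail with tcpip.ErrNoLinkAddress and hand back a channel that is closed once ARP/NDP resolution completes. A minimal standalone sketch of that pattern (writeWithResolution is a hypothetical helper name; it assumes an already-bound endpoint and the tcpip API used in this test):

// writeWithResolution retries a UDP write once after waiting for link
// resolution, mirroring the test's write closure.
func writeWithResolution(ep tcpip.Endpoint, data []byte, to tcpip.FullAddress) (int64, *tcpip.Error) {
	payload := tcpip.SlicePayload(data)
	opts := tcpip.WriteOptions{To: &to}
	n, ch, err := ep.Write(payload, opts)
	if err == tcpip.ErrNoLinkAddress {
		<-ch // Closed when ARP/NDP resolution for the route completes.
		n, _, err = ep.Write(payload, opts)
	}
	return n, err
}
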
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
new file mode 100644
index 000000000..bf3a6f6ee
--- /dev/null
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+)
+
+// TestPing tests that two hosts can ping each other when link resolution is
+// enabled.
+func TestPing(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ }{
+ {
+ name: "IPv4 Ping",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ {
+ name: "IPv6 Ping",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, waiterCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ // The first write should trigger link resolution.
+ icmpBuf := test.icmpBuf(t)
+ wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
+ if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress)
+ } else {
+ // Wait for link resolution to complete.
+ <-ch
+ }
+ if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(icmpBuf)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+
+ // Wait for the endpoint to be readable.
+ <-waiterCH
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.remoteAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ }
+ })
+ }
+}
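
Editor's note: icmpDataOffset = 8 above works for both address families because ICMPv4 and ICMPv6 echo messages share the same 8-byte header layout ahead of the payload. A sketch of that layout (the constant names here are illustrative, not from the header package):

// Byte offsets within an ICMP echo request/reply; the echo data begins at
// offset 8, which is what icmpDataOffset relies on for both families.
const (
	echoTypeOffset     = 0 // Type (1 byte).
	echoCodeOffset     = 1 // Code (1 byte).
	echoChecksumOffset = 2 // Checksum (2 bytes).
	echoIdentOffset    = 4 // Identifier (2 bytes).
	echoSeqOffset      = 6 // Sequence number (2 bytes).
	echoPayloadOffset  = 8 // Echo data starts here.
)
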
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 72d86b5ab..4f2ca7f54 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
if src != expectedSrc {
t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6891fd245..0aaef495d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
@@ -1219,12 +1219,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
- // Increase counter if after processing the segment we would potentially
- // advertise a zero window.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
-
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7ad894840..3bcd3923a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -248,6 +248,11 @@ type ReceiveErrors struct {
// ZeroRcvWindowState is the number of times we advertised
// a zero receive window when rcvList is full.
ZeroRcvWindowState tcpip.StatCounter
+
+ // WantZeroRcvWindow is the number of times we wanted to advertise a
+ // zero receive window but couldn't because it would have caused
+ // the receive window's right edge to shrink.
+ WantZeroRcvWindow tcpip.StatCounter
}
// SendErrors collect segment send errors within the transport layer.
@@ -1162,7 +1167,7 @@ func (e *endpoint) cleanupLocked() {
// wndFromSpace returns the window that we can advertise based on the available
// receive buffer space.
func wndFromSpace(space int) int {
- return space / (1 << rcvAdvWndScale)
+ return space >> rcvAdvWndScale
}
// initialReceiveWindow returns the initial receive window to advertise in the
@@ -1518,6 +1523,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
+// selectWindowLocked returns the new window to advertise; it does not check
+// for window shrinking, and the returned value is unscaled.
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
+ maxWindow := wndFromSpace(e.rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvBufUsed
+
+ // We take the lesser of wndFromAvailable and wndFromUsedBytes because when
+ // we receive many small segments the per-segment overhead is much higher,
+ // and we can run out of socket buffer space before we fill the previously
+ // advertised window. When we receive MSS-sized or close to MSS-sized
+ // segments we will probably run out of window space before we exhaust the
+ // receive buffer.
+ newWnd := wndFromAvailable
+ if newWnd > wndFromUsedBytes {
+ newWnd = wndFromUsedBytes
+ }
+ if newWnd < 0 {
+ newWnd = 0
+ }
+ return seqnum.Size(newWnd)
+}
+
+// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+func (e *endpoint) selectWindow() (wnd seqnum.Size) {
+ e.rcvListMu.Lock()
+ wnd = e.selectWindowLocked()
+ e.rcvListMu.Unlock()
+ return wnd
+}
+
// windowCrossedACKThresholdLocked checks if the receive window to be announced
// would be under aMSS or under the window derived from half receive buffer,
+// whichever is smaller. This is useful as a receive side silly window syndrome
@@ -1534,7 +1571,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
+ newAvail := int(e.selectWindowLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -2099,7 +2136,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID)
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
e.UnlockUser()
if err != nil {
return err
@@ -3013,6 +3050,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
EndSequence: rc.endSequence,
FACK: rc.fack,
RTT: rc.rtt,
+ Reord: rc.reorderSeen,
}
return s
}
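
Editor's note: selectWindowLocked takes the smaller of two candidate windows, one derived from free buffer space and one from the unused portion of the maximum window. A standalone sketch of the arithmetic with plain ints (not the endpoint code itself; wndFromSpace is the right shift shown above):

// selectWindowSketch models selectWindowLocked. For example, with
// rcvAdvWndScale = 2, rcvBufSize = 64<<10, rcvBufUsed = 10<<10 and
// available = 40<<10, it returns min(10240, 6144) = 6144.
func selectWindowSketch(available, rcvBufSize, rcvBufUsed, rcvAdvWndScale int) int {
	wndFromAvailable := available >> rcvAdvWndScale // wndFromSpace(available)
	maxWindow := rcvBufSize >> rcvAdvWndScale       // wndFromSpace(rcvBufSize)
	wndFromUsedBytes := maxWindow - rcvBufUsed

	newWnd := wndFromAvailable
	if newWnd > wndFromUsedBytes {
		newWnd = wndFromUsedBytes
	}
	if newWnd < 0 {
		newWnd = 0
	}
	return newWnd
}
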
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index d969ca23a..d312b1b8b 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -29,26 +29,36 @@ import (
//
// +stateify savable
type rackControl struct {
- // xmitTime is the latest transmission timestamp of rackControl.seg.
- xmitTime time.Time `state:".(unixTime)"`
-
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
+ // dsack indicates if the connection has seen a DSACK.
+ dsack bool
+
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
+ // minRTT is the estimated minimum RTT of the connection.
+ minRTT time.Duration
+
// rtt is the RTT of the most recently delivered packet on the
// connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
rtt time.Duration
+
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
}
-// Update will update the RACK related fields when an ACK has been received.
+// update updates the RACK related fields when an ACK has been received.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
-func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
rtt := time.Now().Sub(seg.xmitTime)
// If the ACK is for a retransmitted packet, do not update if it is a
@@ -65,12 +75,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
return
}
}
- if rtt < srtt {
+ if rtt < rc.minRTT {
return
}
}
rc.rtt = rtt
+
+ // The sender can either track a simple global minimum of all RTT
+ // measurements from the connection, or a windowed min-filtered value
+ // of recent RTT measurements. This implementation keeps track of the
+ // simple global minimum of all RTTs for the connection.
+ if rtt < rc.minRTT || rc.minRTT == 0 {
+ rc.minRTT = rtt
+ }
+
// Update rc.xmitTime and rc.endSequence to the transmit time and
// ending sequence number of the packet which has been acknowledged
// most recently.
@@ -80,3 +99,26 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
rc.endSequence = endSeq
}
}
+
+// detectReorder detects if packet reordering has been observed.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 3: Detect data segment reordering.
+// To detect reordering, the sender looks for original data segments being
+// delivered out of order. To detect such cases, the sender tracks the
+// highest sequence selectively or cumulatively acknowledged in the RACK.fack
+// variable. The name "fack" stands for the most "Forward ACK" (this term is
+// adopted from [FACK]). If a never retransmitted segment that's below
+// RACK.fack is (selectively or cumulatively) acknowledged, it has been
+// delivered out of order. The sender sets RACK.reord to TRUE if such segment
+// is identified.
+func (rc *rackControl) detectReorder(seg *segment) {
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.fack.LessThan(endSeq) {
+ rc.fack = endSeq
+ return
+ }
+
+ if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
+ rc.reorderSeen = true
+ }
+}
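
Editor's note: detectReorder either advances fack or flags reordering, never both for the same segment. A minimal standalone model of the check (plain uint32s stand in for seqnum.Value, wraparound is ignored, and neverRetransmitted mirrors seg.xmitCount == 1):

// detectReorderSketch models rackControl.detectReorder: fack tracks the
// highest end sequence acked so far; acking a never-retransmitted segment
// that ends below fack means it was delivered out of order.
func detectReorderSketch(fack *uint32, endSeq uint32, neverRetransmitted bool) (reord bool) {
	if *fack < endSeq {
		*fack = endSeq // Advance the forward ACK point.
		return false
	}
	return endSeq < *fack && neverRetransmitted
}

For instance, with fack at 3000, acking an original segment ending at 2000 reports reordering, while acking a retransmission does not.
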
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 48bf196d8..8e0b7c843 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -43,6 +43,9 @@ type receiver struct {
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
+ // rcvWUP is the rcvNxt value at the last window update sent.
+ rcvWUP seqnum.Value
+
rcvWndScale uint8
closed bool
@@ -64,6 +67,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
+ rcvWUP: irs + 1,
rcvWndScale: rcvWndScale,
lastRcvdAckTime: time.Now(),
}
@@ -84,34 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
}
+// currentWindow returns the available space in the window that was last
+// advertised to our peer.
+func (r *receiver) currentWindow() (curWnd seqnum.Size) {
+ endOfWnd := r.rcvWUP.Add(r.rcvWnd)
+ if endOfWnd.LessThan(r.rcvNxt) {
+ // Return 0 if r.rcvNxt is past the end of the previously advertised window.
+ // This can happen because we accept a large segment completely even if
+ // accepting it causes it to partially exceed the advertised window.
+ return 0
+ }
+ return r.rcvNxt.Size(endOfWnd)
+}
+
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- avail := wndFromSpace(r.ep.receiveBufferAvailable())
- if avail == 0 {
- // We have no space available to accept any data, move to zero window
- // state.
- r.rcvWnd = 0
- return r.rcvNxt, 0
- }
-
- acc := r.rcvNxt.Add(seqnum.Size(avail))
- newWnd := r.rcvNxt.Size(acc)
- curWnd := r.rcvNxt.Size(r.rcvAcc)
-
+ newWnd := r.ep.selectWindow()
+ curWnd := r.currentWindow()
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised to the peer. If we shrink the acceptable sequence space then we
// would end up dropping bytes that might already be in flight.
- if newWnd > curWnd {
- r.rcvAcc = r.rcvNxt.Add(newWnd)
+ // ==================================================== sequence space.
+ // ^ ^ ^ ^
+ // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // <=====curWnd ===>
+ // <========= newWnd > curWnd ========= >
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ // If the new window moves the right edge, then update rcvAcc.
+ r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
+ if newWnd == 0 {
+ // newWnd is zero but we can't advertise it, as doing so would shrink the
+ // window's right edge, so just increment a metric to record this event.
+ r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment()
+ }
newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
- return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
+ r.rcvWUP = r.rcvNxt
+ scaledWnd := r.rcvWnd >> r.rcvWndScale
+ if scaledWnd == 0 {
+ // Increment a metric if we are advertising an actual zero window.
+ r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+ return r.rcvNxt, scaledWnd
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
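
Editor's note: the net effect of getSendParams above is that the advertised right edge (rcvWUP + rcvWnd) never moves left. A toy model with plain ints (ignoring sequence wraparound, which the seqnum types handle in the real code):

// advertiseSketch models the non-shrinking rule in getSendParams: if the
// newly selected window would end before the previously advertised right
// edge, keep advertising the space left under the old edge instead.
func advertiseSketch(rcvNxt, rcvWUP, rcvWnd, newWnd int) int {
	curWnd := rcvWUP + rcvWnd - rcvNxt // currentWindow()
	if curWnd < 0 {
		curWnd = 0 // rcvNxt overran the old window (oversized segment accepted).
	}
	if newWnd < curWnd {
		newWnd = curWnd // Never pull the right edge back.
	}
	return newWnd
}
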
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 13acaf753..1f9c5cf50 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -71,6 +71,9 @@ type segment struct {
// xmitTime is the last transmit time of this segment.
xmitTime time.Time `state:".(unixTime)"`
xmitCount uint32
+
+ // acked indicates if the segment has already been SACKed.
+ acked bool
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index c55589c45..6fa8d63cd 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "sort"
"sync/atomic"
"time"
@@ -263,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
+ rc: rackControl{
+ fack: iss,
+ },
gso: ep.gso != nil,
}
@@ -1274,6 +1278,39 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
return true
}
+// walkSACK iterates the writeList and updates RACK for each segment that is
+// newly acked, either cumulatively or selectively. It loops through the
+// sacked segments, updates the RACK related variables and checks for
+// reordering.
+//
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// steps 2 and 3.
+func (s *sender) walkSACK(rcvdSeg *segment) {
+ // Sort the SACK blocks. The first block is the most recent unacked
+ // block. The following blocks can be in arbitrary order.
+ sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sort.Slice(sackBlocks, func(i, j int) bool {
+ return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
+ })
+
+ seg := s.writeList.Front()
+ for _, sb := range sackBlocks {
+ // This check excludes DSACK blocks.
+ if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
+ continue
+ }
+
+ for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
+ if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
+ seg.acked = true
+ }
+ seg = seg.Next()
+ }
+ }
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1308,6 +1345,21 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
rcvdSeg.hasNewSACKInfo = true
}
}
+
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08
+ // section-7.2
+ // * Step 2: Update RACK stats.
+ // If the ACK is not ignored as invalid, update the RACK.rtt
+ // to be the RTT sample calculated using this ACK, and
+ // continue. If this ACK or SACK was for the most recently
+ // sent packet, then record the RACK.xmit_ts timestamp and
+ // RACK.end_seq sequence implied by this ACK.
+ // * Step 3: Detect packet reordering.
+ // If the ACK selectively or cumulatively acknowledges an
+ // unacknowledged and also never retransmitted sequence below
+ // RACK.fack, then the corresponding packet has been
+ // reordered and RACK.reord is set to TRUE.
+ s.walkSACK(rcvdSeg)
s.SetPipe()
}
@@ -1365,9 +1417,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
- s.rtt.Lock()
- srtt := s.rtt.srtt
- s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1388,13 +1437,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted {
- s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ if s.ep.sackPermitted && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
}
s.writeList.Remove(seg)
- // if SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
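
Editor's note: the comparator in walkSACK indexes j before i, which sorts the SACK blocks in descending order of their start sequence so the most recently reported block is processed first. A self-contained illustration of that comparator, with plain uint32 starts standing in for seqnum.Value:

package main

import (
	"fmt"
	"sort"
)

func main() {
	// Swapping i and j in the comparator yields a descending sort, as in
	// walkSACK's sort.Slice call.
	starts := []uint32{100, 900, 500}
	sort.Slice(starts, func(i, j int) bool { return starts[j] < starts[i] })
	fmt.Println(starts) // [900 500 100]: highest start (most recent block) first.
}
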
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index e03f101e8..d3f92b48c 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -21,17 +21,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
+const (
+ maxPayload = 10
+ tsOptionSize = 12
+ maxTCPOptionSize = 40
+)
+
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
- const maxPayload = 10
- const tsOptionSize = 12
- const maxTCPOptionSize = 40
-
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
defer c.Cleanup()
@@ -49,7 +52,7 @@ func TestRACKUpdate(t *testing.T) {
}
if state.Sender.RACKState.RTT == 0 {
- t.Fatalf("RACK RTT failed to update when an ACK is received")
+ t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
}
})
setStackSACKPermitted(t, c, true)
@@ -69,6 +72,66 @@ func TestRACKUpdate(t *testing.T) {
bytesRead := 0
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- c.SendAck(790, bytesRead)
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
time.Sleep(200 * time.Millisecond)
}
+
+// TestRACKDetectReorder tests that RACK detects packet reordering.
+func TestRACKDetectReorder(t *testing.T) {
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ const ackNum = 2
+
+ var n int
+ ch := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ gotSeq := state.Sender.RACKState.FACK
+ wantSeq := state.Sender.SndNxt
+ // FACK should be updated to the highest ending sequence number of the
+ // segment acknowledged most recently.
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ n++
+ if n < ackNum {
+ if state.Sender.RACKState.Reord {
+ t.Fatalf("RACK reorder detected when there is no reordering")
+ }
+ return
+ }
+
+ if state.Sender.RACKState.Reord == false {
+ t.Fatalf("RACK reorder detection failed")
+ }
+ close(ch)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ data := buffer.NewView(ackNum * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < ackNum; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ start := c.IRS.Add(maxPayload + 1)
+ end := start.Add(maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-ch
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5b504d0d1..a7149efd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6264,14 +6264,27 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
+
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- pkt := c.GetPacket()
- curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // The read loop above could generate an ACK if the window had dropped
+ // to zero and a subsequent read opened it back up.
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of a few milliseconds.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
+ }
+ lastACK = pkt
+ }
+ curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
// If the new current window is close to maxReceiveBufferSize then terminate
// the loop. This can happen before all iterations are done due to timing
// differences when running the test.
@@ -7328,7 +7341,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf * 2
+ remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
@@ -7343,7 +7356,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index faf51ef95..4d7847142 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -68,9 +68,9 @@ const (
// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- // testInitialSequenceNumber is the initial sequence number sent in packets that
+ // TestInitialSequenceNumber is the initial sequence number sent in packets that
// are sent in response to a SYN or in the initial SYN sent to the stack.
- testInitialSequenceNumber = 789
+ TestInitialSequenceNumber = 789
)
// StackAddrWithPrefix is StackAddr with its associated prefix length.
@@ -505,7 +505,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -532,7 +532,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -912,7 +912,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -1084,7 +1084,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
offset += paddingToAdd
// Send a SYN request.
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: StackPort,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4e14a5fc5..b4604ba35 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1432,7 +1432,9 @@ func TestNoChecksum(t *testing.T) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -1450,10 +1452,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1780,16 +1778,26 @@ func TestV4UnknownDestination(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4DstUnreachable),
checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ // We need to compare the data payload of the UDP packet embedded in the
+ // ICMP error packet with the matching original data.
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ // To work out the data size we need to simulate what the sender would
+ // have done. The wanted size is the total available minus the sum of
+ // the headers in the UDP AND ICMP packets, given that we know the test
+ // had only a minimal IP header but the ICMP sender will have allowed
+ // for a maximally sized packet header.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+
}
- // In case of large payloads the IP packet may be truncated. Update
+ // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // Add back the two headers within the payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram.Payload()), wantLen; got != want {
@@ -2015,7 +2023,8 @@ func TestPayloadModifiedV4(t *testing.T) {
payload := newPayload()
h := unicastV4.header4Tuple(incoming)
buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
@@ -2045,7 +2054,8 @@ func TestPayloadModifiedV6(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
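
Editor's note: the testInterface change above uses interface embedding: embedding stack.NetworkLinkEndpoint lets the struct satisfy the interface while overriding only the methods the tests exercise; calling an unimplemented method would panic via the nil embedded value, which is acceptable in tests that never reach it. The same pattern in miniature (Speaker and fakeSpeaker are hypothetical names):

// Speaker is a hypothetical two-method interface.
type Speaker interface {
	Hello() string
	Bye() string
}

type fakeSpeaker struct {
	Speaker // Nil; embedded only so fakeSpeaker satisfies Speaker.
}

// Hello is the only method the tests call; Bye would panic via the nil
// embedded Speaker if invoked.
func (fakeSpeaker) Hello() string { return "hi" }
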
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 06fb823f6..49ab87c58 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -270,7 +270,7 @@ func RandomID(prefix string) string {
// same name, sometimes between test runs the socket does not get cleaned up
// quickly enough, causing container creation to fail.
func RandomContainerID() string {
- return RandomID("test-container-")
+ return RandomID("test-container")
}
// Copy copies file from src to dst.