summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD2
-rw-r--r--pkg/abi/linux/ioctl.go3
-rw-r--r--pkg/abi/linux/sem.go12
-rw-r--r--pkg/abi/linux/sem_amd64.go33
-rw-r--r--pkg/abi/linux/sem_arm64.go31
-rw-r--r--pkg/bpf/decoder.go2
-rw-r--r--pkg/context/context.go24
-rw-r--r--pkg/merkletree/BUILD10
-rw-r--r--pkg/merkletree/merkletree.go118
-rw-r--r--pkg/merkletree/merkletree_test.go407
-rw-r--r--pkg/refs/refcounter.go10
-rw-r--r--pkg/refsvfs2/BUILD (renamed from pkg/refs_vfs2/BUILD)19
-rw-r--r--pkg/refsvfs2/refs.go (renamed from pkg/refs_vfs2/refs.go)4
-rw-r--r--pkg/refsvfs2/refs_map.go131
-rw-r--r--pkg/refsvfs2/refs_template.go (renamed from pkg/refs_vfs2/refs_template.go)76
-rw-r--r--pkg/sentry/control/state.go2
-rw-r--r--pkg/sentry/devices/tundev/tundev.go4
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go7
-rw-r--r--pkg/sentry/fs/gofer/path.go3
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go83
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go6
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD3
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go17
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go2
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/BUILD5
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/save_restore.go23
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD3
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go20
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go6
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD3
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go52
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go482
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go22
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go329
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go108
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go12
-rw-r--r--pkg/sentry/fsimpl/host/BUILD7
-rw-r--r--pkg/sentry/fsimpl/host/control.go2
-rw-r--r--pkg/sentry/fsimpl/host/host.go255
-rw-r--r--pkg/sentry/fsimpl/host/save_restore.go70
-rw-r--r--pkg/sentry/fsimpl/host/util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD40
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go14
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go93
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go112
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go264
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go62
-rw-r--r--pkg/sentry/fsimpl/kernfs/mmap_util.go (renamed from pkg/sentry/fsimpl/host/mmap.go)83
-rw-r--r--pkg/sentry/fsimpl/kernfs/save_restore.go36
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go8
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go11
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD3
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go146
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go242
-rw-r--r--pkg/sentry/fsimpl/overlay/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go2
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD11
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go36
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task.go7
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go16
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go173
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go34
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go30
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go8
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go116
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go1
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go4
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD3
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go2
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go70
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/save_restore.go20
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go11
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD6
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go138
-rw-r--r--pkg/sentry/fsimpl/verity/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go246
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go834
-rw-r--r--pkg/sentry/hostfd/BUILD2
-rw-r--r--pkg/sentry/hostfd/hostfd_linux.go18
-rw-r--r--pkg/sentry/hostfd/hostfd_unsafe.go9
-rw-r--r--pkg/sentry/inet/inet.go6
-rw-r--r--pkg/sentry/inet/test_stack.go21
-rw-r--r--pkg/sentry/kernel/BUILD12
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go10
-rw-r--r--pkg/sentry/kernel/fd_table.go6
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go13
-rw-r--r--pkg/sentry/kernel/fs_context.go24
-rw-r--r--pkg/sentry/kernel/ipc_namespace.go2
-rw-r--r--pkg/sentry/kernel/kernel.go176
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go8
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go50
-rw-r--r--pkg/sentry/kernel/pipe/pipe_test.go11
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go8
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go75
-rw-r--r--pkg/sentry/kernel/sessions.go32
-rw-r--r--pkg/sentry/kernel/shm/BUILD4
-rw-r--r--pkg/sentry/kernel/task_clone.go2
-rw-r--r--pkg/sentry/kernel/task_usermem.go95
-rw-r--r--pkg/sentry/kernel/vdso.go2
-rw-r--r--pkg/sentry/mm/BUILD5
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go17
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go17
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go7
-rw-r--r--pkg/sentry/platform/kvm/kvm.go3
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go5
-rw-r--r--pkg/sentry/platform/kvm/machine.go48
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go63
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go33
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go67
-rw-r--r--pkg/sentry/platform/ring0/defs.go3
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go10
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go7
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s67
-rw-r--r--pkg/sentry/platform/ring0/kernel.go6
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go5
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go4
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go45
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go84
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go21
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go15
-rw-r--r--pkg/sentry/platform/ring0/pagetables/walker_arm64.go2
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go10
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/hostinet/stack.go21
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go63
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go4
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go4
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go52
-rw-r--r--pkg/sentry/socket/netfilter/targets.go217
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go4
-rw-r--r--pkg/sentry/socket/netstack/stack.go85
-rw-r--r--pkg/sentry/socket/unix/BUILD5
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go16
-rw-r--r--pkg/sentry/socket/unix/unix.go1
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/state/BUILD2
-rw-r--r--pkg/sentry/state/state.go10
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go52
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go30
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go23
-rw-r--r--pkg/sentry/vfs/BUILD8
-rw-r--r--pkg/sentry/vfs/epoll.go2
-rw-r--r--pkg/sentry/vfs/file_description.go1
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go11
-rw-r--r--pkg/sentry/vfs/inotify.go2
-rw-r--r--pkg/sentry/vfs/lock.go5
-rw-r--r--pkg/sentry/vfs/mount.go69
-rw-r--r--pkg/sentry/vfs/mount_test.go35
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go21
-rw-r--r--pkg/sentry/vfs/save_restore.go124
-rw-r--r--pkg/sentry/vfs/vfs.go24
-rw-r--r--pkg/shim/runsc/BUILD1
-rw-r--r--pkg/shim/runsc/runsc.go12
-rw-r--r--pkg/state/BUILD14
-rw-r--r--pkg/state/decode.go80
-rw-r--r--pkg/state/decode_unsafe.go57
-rw-r--r--pkg/state/encode.go221
-rw-r--r--pkg/state/pretty/pretty.go41
-rw-r--r--pkg/state/state.go10
-rw-r--r--pkg/state/tests/struct.go35
-rw-r--r--pkg/state/tests/struct_test.go34
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go11
-rw-r--r--pkg/tcpip/checker/checker.go49
-rw-r--r--pkg/tcpip/header/icmpv4.go9
-rw-r--r--pkg/tcpip/header/ipv4.go493
-rw-r--r--pkg/tcpip/header/ipv6.go12
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go36
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go7
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go89
-rw-r--r--pkg/tcpip/network/arp/arp_test.go201
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go58
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go120
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go22
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go23
-rw-r--r--pkg/tcpip/network/ip_test.go115
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go168
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go542
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go848
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go135
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go139
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go275
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go458
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go18
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go12
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go6
-rw-r--r--pkg/tcpip/stack/iptables.go75
-rw-r--r--pkg/tcpip/stack/iptables_targets.go102
-rw-r--r--pkg/tcpip/stack/iptables_types.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go8
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go33
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go531
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go221
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2487
-rw-r--r--pkg/tcpip/stack/nic.go63
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/nud.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer.go60
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go51
-rw-r--r--pkg/tcpip/stack/route.go288
-rw-r--r--pkg/tcpip/stack/stack.go324
-rw-r--r--pkg/tcpip/stack/stack_test.go568
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go63
-rw-r--r--pkg/tcpip/stack/transport_test.go35
-rw-r--r--pkg/tcpip/tcpip.go24
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go46
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go48
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go2
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go24
-rw-r--r--pkg/tcpip/tests/integration/route_test.go388
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go16
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go38
-rw-r--r--pkg/tcpip/transport/tcp/accept.go93
-rw-r--r--pkg/tcpip/transport/tcp/connect.go30
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go7
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go12
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go31
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment.go47
-rw-r--r--pkg/tcpip/transport/tcp/snd.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go20
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go98
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go50
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go24
-rw-r--r--pkg/tcpip/transport/udp/protocol.go8
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go5
-rw-r--r--pkg/unet/unet_test.go157
-rw-r--r--pkg/waiter/waiter.go2
269 files changed, 13086 insertions, 5079 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 4a26e28de..a0654df2f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -55,6 +55,8 @@ go_library(
"sched.go",
"seccomp.go",
"sem.go",
+ "sem_amd64.go",
+ "sem_arm64.go",
"shm.go",
"signal.go",
"signalfd.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 7df02dd6d..006b5a525 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,6 +121,9 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
+ FS_VERITY_HASH_ALG_SHA256 = 1
+ FS_VERITY_HASH_ALG_SHA512 = 2
+
FS_IOC_ENABLE_VERITY = 1082156677
FS_IOC_MEASURE_VERITY = 3221513862
)
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 487a626cc..1b2f76c0b 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -34,18 +34,6 @@ const (
const SEM_UNDO = 0x1000
-// SemidDS is equivalent to struct semid64_ds.
-//
-// +marshal
-type SemidDS struct {
- SemPerm IPCPerm
- SemOTime TimeT
- SemCTime TimeT
- SemNSems uint64
- unused3 uint64
- unused4 uint64
-}
-
// Sembuf is equivalent to struct sembuf.
//
// +marshal slice:SembufSlice
diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go
new file mode 100644
index 000000000..ab980cb4f
--- /dev/null
+++ b/pkg/abi/linux/sem_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: arch/x86/include/uapi/asm/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ unused1 uint64
+ SemCTime TimeT
+ unused2 uint64
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go
new file mode 100644
index 000000000..521468fb1
--- /dev/null
+++ b/pkg/abi/linux/sem_arm64.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: include/uapi/asm-generic/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ SemCTime TimeT
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
index 069d0395d..6d1e65cb1 100644
--- a/pkg/bpf/decoder.go
+++ b/pkg/bpf/decoder.go
@@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error {
case B:
w.WriteString("1")
default:
- return fmt.Errorf("Invalid BPF LD size: %v", inst)
+ return fmt.Errorf("invalid BPF LD size: %v", inst)
}
return nil
}
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 2613bc752..f3031fc60 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()}
func Background() Context {
return bgContext
}
+
+// WithValue returns a copy of parent in which the value associated with key is
+// val.
+func WithValue(parent Context, key, val interface{}) Context {
+ return &withValue{
+ Context: parent,
+ key: key,
+ val: val,
+ }
+}
+
+type withValue struct {
+ Context
+ key interface{}
+ val interface{}
+}
+
+// Value implements Context.Value.
+func (ctx *withValue) Value(key interface{}) interface{} {
+ if key == ctx.key {
+ return ctx.val
+ }
+ return ctx.Context.Value(key)
+}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
index a8fcb2e19..501a9ef21 100644
--- a/pkg/merkletree/BUILD
+++ b/pkg/merkletree/BUILD
@@ -6,12 +6,18 @@ go_library(
name = "merkletree",
srcs = ["merkletree.go"],
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
go_test(
name = "merkletree_test",
srcs = ["merkletree_test.go"],
library = ":merkletree",
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index d8227b8bd..e0a9e56c5 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -18,21 +18,32 @@ package merkletree
import (
"bytes"
"crypto/sha256"
+ "crypto/sha512"
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
const (
// sha256DigestSize specifies the digest size of a SHA256 hash.
sha256DigestSize = 32
+ // sha512DigestSize specifies the digest size of a SHA512 hash.
+ sha512DigestSize = 64
)
// DigestSize returns the size (in bytes) of a digest.
-// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
-func DigestSize() int {
- return sha256DigestSize
+// TODO(b/156980949): Allow config SHA384.
+func DigestSize(hashAlgorithm int) int {
+ switch hashAlgorithm {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ return sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ return sha512DigestSize
+ default:
+ return -1
+ }
}
// Layout defines the scale of a Merkle tree.
@@ -51,11 +62,19 @@ type Layout struct {
// InitLayout initializes and returns a new Layout object describing the structure
// of a tree. dataSize specifies the size of input data in bytes.
-func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
+func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) {
layout := Layout{
blockSize: usermem.PageSize,
- // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
- digestSize: sha256DigestSize,
+ }
+
+ // TODO(b/156980949): Allow config SHA384.
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ layout.digestSize = sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ layout.digestSize = sha512DigestSize
+ default:
+ return Layout{}, fmt.Errorf("unexpected hash algorithms")
}
// treeStart is the offset (in bytes) of the first level of the tree in
@@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
}
layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize)
- return layout
+ return layout, nil
}
// hashesPerBlock() returns the number of digests in each block. For example,
@@ -128,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
// meatadata.
type VerityDescriptor struct {
Name string
+ FileSize int64
Mode uint32
UID uint32
GID uint32
@@ -135,16 +155,37 @@ type VerityDescriptor struct {
}
func (d *VerityDescriptor) String() string {
- return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+ return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash)
}
// verify generates a hash from d, and compares it with expected.
-func (d *VerityDescriptor) verify(expected []byte) error {
- h := sha256.Sum256([]byte(d.String()))
+func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error {
+ h, err := hashData([]byte(d.String()), hashAlgorithms)
+ if err != nil {
+ return err
+ }
if !bytes.Equal(h[:], expected) {
return fmt.Errorf("unexpected root hash")
}
return nil
+
+}
+
+// hashData hashes data and returns the result hash based on the hash
+// algorithms.
+func hashData(data []byte, hashAlgorithms int) ([]byte, error) {
+ var digest []byte
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ digestArray := sha256.Sum256(data)
+ digest = digestArray[:]
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ digestArray := sha512.Sum512(data)
+ digest = digestArray[:]
+ default:
+ return nil, fmt.Errorf("unexpected hash algorithms")
+ }
+ return digest, nil
}
// GenerateParams contains the parameters used to generate a Merkle tree.
@@ -161,6 +202,8 @@ type GenerateParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// TreeReader is a reader for the Merkle tree.
TreeReader io.ReaderAt
// TreeWriter is a writer for the Merkle tree.
@@ -176,7 +219,10 @@ type GenerateParams struct {
// Generate returns a hash of a VerityDescriptor, which contains the file
// metadata and the hash from file content.
func Generate(params *GenerateParams) ([]byte, error) {
- layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
+ layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return nil, err
+ }
numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
@@ -218,10 +264,13 @@ func Generate(params *GenerateParams) ([]byte, error) {
return nil, err
}
// Hash the bytes in buf.
- digest := sha256.Sum256(buf)
+ digest, err := hashData(buf, params.HashAlgorithms)
+ if err != nil {
+ return nil, err
+ }
if level == layout.rootLevel() {
- root = digest[:]
+ root = digest
}
// Write the generated hash to the end of the tree file.
@@ -241,13 +290,13 @@ func Generate(params *GenerateParams) ([]byte, error) {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
RootHash: root,
}
- ret := sha256.Sum256([]byte(descriptor.String()))
- return ret[:], nil
+ return hashData([]byte(descriptor.String()), params.HashAlgorithms)
}
// VerifyParams contains the params used to verify a portion of a file against
@@ -269,6 +318,8 @@ type VerifyParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// ReadOffset is the offset of the data range to be verified.
ReadOffset int64
// ReadSize is the size of the data range to be verified.
@@ -293,12 +344,13 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
RootHash: root,
}
- return descriptor.verify(params.Expected)
+ return descriptor.verify(params.Expected, params.HashAlgorithms)
}
// Verify verifies the content read from data with offset. The content is
@@ -313,7 +365,10 @@ func Verify(params *VerifyParams) (int64, error) {
if params.ReadSize < 0 {
return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
}
- layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return 0, err
+ }
if params.ReadSize == 0 {
return 0, verifyMetadata(params, &layout)
}
@@ -349,12 +404,13 @@ func Verify(params *VerifyParams) (int64, error) {
}
}
descriptor := VerityDescriptor{
- Name: params.Name,
- Mode: params.Mode,
- UID: params.UID,
- GID: params.GID,
+ Name: params.Name,
+ FileSize: params.Size,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
}
- if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil {
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
return 0, err
}
@@ -395,7 +451,7 @@ func Verify(params *VerifyParams) (int64, error) {
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
// expected.
-func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error {
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -406,8 +462,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
for level := 0; level < layout.numLevels(); level++ {
// Calculate hash.
if level == 0 {
- digestArray := sha256.Sum256(dataBlock)
- digest = digestArray[:]
+ h, err := hashData(dataBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
} else {
// Read a block in previous level that contains the
// hash we just generated, and generate a next level
@@ -415,8 +474,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil {
return err
}
- digestArray := sha256.Sum256(treeBlock)
- digest = digestArray[:]
+ h, err := hashData(treeBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
}
// Read the digest for the current block and store in
@@ -434,5 +496,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
// Verification for the tree succeeded. Now hash the descriptor with
// the root hash and compare it with expected.
descriptor.RootHash = digest
- return descriptor.verify(expected)
+ return descriptor.verify(expected, hashAlgorithms)
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index e1350ebda..405204d94 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -22,54 +22,114 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
func TestLayout(t *testing.T) {
testCases := []struct {
dataSize int64
+ hashAlgorithms int
dataAndTreeInSameFile bool
+ expectedDigestSize int64
expectedLevelOffset []int64
}{
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0},
},
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
+ expectedLevelOffset: []int64{usermem.PageSize},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
expectedLevelOffset: []int64{usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize},
+ },
+ {
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize},
},
{
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize},
+ },
+ {
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize},
},
{
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize},
},
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize},
+ },
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
- l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile)
+ l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile)
+ if err != nil {
+ t.Fatalf("Failed to InitLayout: %v", err)
+ }
if l.blockSize != int64(usermem.PageSize) {
t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
}
- if l.digestSize != sha256DigestSize {
+ if l.digestSize != tc.expectedDigestSize {
t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
}
if l.numLevels() != len(tc.expectedLevelOffset) {
@@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
- data []byte
- expectedHash []byte
+ data []byte
+ hashAlgorithms int
+ expectedHash []byte
}{
{
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178},
},
{
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171},
},
{
- data: []byte{'a'},
- expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120},
},
{
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246},
+ },
+ {
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84},
},
}
@@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) {
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
+ HashAlgorithms: tc.hashAlgorithms,
TreeReader: &tree,
TreeWriter: &tree,
DataAndTreeInSameFile: dataAndTreeInSameFile,
@@ -189,6 +275,7 @@ func TestVerify(t *testing.T) {
// fail, otherwise Verify should still succeed.
modifyByte int64
modifyName bool
+ modifySize bool
modifyMode bool
modifyUID bool
modifyGID bool
@@ -237,6 +324,15 @@ func TestVerify(t *testing.T) {
modifyName: true,
shouldSucceed: false,
},
+ // Modified size should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifySize: true,
+ shouldSucceed: false,
+ },
// Modified mode should fail verification.
{
dataSize: usermem.PageSize,
@@ -348,77 +444,84 @@ func TestVerify(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Flip a bit in data and checks Verify results.
- var buf bytes.Buffer
- data[tc.modifyByte] ^= 1
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: tc.dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: tc.verifyStart,
- ReadSize: tc.verifySize,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if tc.modifyName {
- verifyParams.Name = defaultName + "abc"
- }
- if tc.modifyMode {
- verifyParams.Mode = defaultMode + 1
- }
- if tc.modifyUID {
- verifyParams.UID = defaultUID + 1
- }
- if tc.modifyGID {
- verifyParams.GID = defaultGID + 1
- }
- if tc.shouldSucceed {
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed when expected to succeed: %v", err)
+ // Flip a bit in data and checks Verify results.
+ var buf bytes.Buffer
+ data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
}
- if n != tc.verifySize {
- t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
}
- if int64(buf.Len()) != tc.verifySize {
- t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ if tc.modifySize {
+ verifyParams.Size--
}
- if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
}
- } else {
- if _, err := Verify(&verifyParams); err == nil {
- t.Errorf("Verification succeeded when expected to fail")
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
+ }
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
+ if tc.shouldSucceed {
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed when expected to succeed: %v", err)
+ }
+ if n != tc.verifySize {
+ t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ }
+ if int64(buf.Len()) != tc.verifySize {
+ t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ }
+ if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
+ } else {
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Errorf("Verification succeeded when expected to fail")
+ }
}
}
}
@@ -435,87 +538,91 @@ func TestVerifyRandom(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Pick a random portion of data.
- start := rand.Int63n(dataSize - 1)
- size := rand.Int63n(dataSize) + 1
+ // Pick a random portion of data.
+ start := rand.Int63n(dataSize - 1)
+ size := rand.Int63n(dataSize) + 1
- var buf bytes.Buffer
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: start,
- ReadSize: size,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: start,
+ ReadSize: size,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- // Checks that the random portion of data from the original data is
- // verified successfully.
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed for correct data: %v", err)
- }
- if size > dataSize-start {
- size = dataSize - start
- }
- if n != size {
- t.Errorf("Got Verify output size %d, want %d", n, size)
- }
- if int64(buf.Len()) != size {
- t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
- }
- if !bytes.Equal(data[start:start+size], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
- }
+ // Checks that the random portion of data from the original data is
+ // verified successfully.
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed for correct data: %v", err)
+ }
+ if size > dataSize-start {
+ size = dataSize - start
+ }
+ if n != size {
+ t.Errorf("Got Verify output size %d, want %d", n, size)
+ }
+ if int64(buf.Len()) != size {
+ t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
+ }
+ if !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
- // Verify that modified metadata should fail verification.
- buf.Reset()
- verifyParams.Name = defaultName + "abc"
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verify succeeded for modified metadata, expect failure")
- }
+ // Verify that modified metadata should fail verification.
+ buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
- // Flip a random bit in randPortion, and check that verification fails.
- buf.Reset()
- randBytePos := rand.Int63n(size)
- data[start+randBytePos] ^= 1
- verifyParams.File = bytes.NewReader(data)
- verifyParams.Name = defaultName
+ // Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
+ randBytePos := rand.Int63n(size)
+ data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verification succeeded for modified data, expect failure")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
+ }
}
}
}
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 699ea8ac3..6992e1de8 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -319,7 +319,8 @@ func makeStackKey(pcs []uintptr) stackKey {
return key
}
-func recordStack() []uintptr {
+// RecordStack constructs and returns the PCs on the current stack.
+func RecordStack() []uintptr {
pcs := make([]uintptr, maxStackFrames)
n := runtime.Callers(1, pcs)
if n == 0 {
@@ -342,7 +343,8 @@ func recordStack() []uintptr {
return v
}
-func formatStack(pcs []uintptr) string {
+// FormatStack converts the given stack into a readable format.
+func FormatStack(pcs []uintptr) string {
frames := runtime.CallersFrames(pcs)
var trace bytes.Buffer
for {
@@ -367,7 +369,7 @@ func (r *AtomicRefCount) finalize() {
if n := r.ReadRefs(); n != 0 {
msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n)
if len(r.stack) != 0 {
- msg += ":\nCaller:\n" + formatStack(r.stack)
+ msg += ":\nCaller:\n" + FormatStack(r.stack)
} else {
msg += " (enable trace logging to debug)"
}
@@ -392,7 +394,7 @@ func (r *AtomicRefCount) EnableLeakCheck(name string) {
case NoLeakChecking:
return
case LeaksLogTraces:
- r.stack = recordStack()
+ r.stack = RecordStack()
}
r.name = name
runtime.SetFinalizer(r, (*AtomicRefCount).finalize)
diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD
index 577b827a5..bfa1daa10 100644
--- a/pkg/refs_vfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -8,6 +8,9 @@ go_template(
srcs = [
"refs_template.go",
],
+ opt_consts = [
+ "logTrace",
+ ],
types = [
"T",
],
@@ -19,8 +22,16 @@ go_template(
)
go_library(
- name = "refs_vfs2",
- srcs = ["refs.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/context"],
+ name = "refsvfs2",
+ srcs = [
+ "refs.go",
+ "refs_map.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go
index 99a074e96..ef8beb659 100644
--- a/pkg/refs_vfs2/refs.go
+++ b/pkg/refsvfs2/refs.go
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package refs_vfs2 defines an interface for a reference-counted object.
-package refs_vfs2
+// Package refsvfs2 defines an interface for a reference-counted object.
+package refsvfs2
import (
"gvisor.dev/gvisor/pkg/context"
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
new file mode 100644
index 000000000..9fbc5466f
--- /dev/null
+++ b/pkg/refsvfs2/refs_map.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refsvfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var (
+ // liveObjects is a global map of reference-counted objects. Objects are
+ // inserted when leak check is enabled, and they are removed when they are
+ // destroyed. It is protected by liveObjectsMu.
+ liveObjects map[CheckedObject]struct{}
+ liveObjectsMu sync.Mutex
+)
+
+// CheckedObject represents a reference-counted object with an informative
+// leak detection message.
+type CheckedObject interface {
+ // RefType is the type of the reference-counted object.
+ RefType() string
+
+ // LeakMessage supplies a warning to be printed upon leak detection.
+ LeakMessage() string
+
+ // LogRefs indicates whether reference-related events should be logged.
+ LogRefs() bool
+}
+
+func init() {
+ liveObjects = make(map[CheckedObject]struct{})
+}
+
+// leakCheckEnabled returns whether leak checking is enabled. The following
+// functions should only be called if it returns true.
+func leakCheckEnabled() bool {
+ return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking
+}
+
+// Register adds obj to the live object map.
+func Register(obj CheckedObject) {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ if _, ok := liveObjects[obj]; ok {
+ panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj))
+ }
+ liveObjects[obj] = struct{}{}
+ liveObjectsMu.Unlock()
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "registered")
+ }
+ }
+}
+
+// Unregister removes obj from the live object map.
+func Unregister(obj CheckedObject) {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ if _, ok := liveObjects[obj]; !ok {
+ panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj))
+ }
+ delete(liveObjects, obj)
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "unregistered")
+ }
+ }
+}
+
+// LogIncRef logs a reference increment.
+func LogIncRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("IncRef to %d", refs))
+ }
+}
+
+// LogTryIncRef logs a successful TryIncRef call.
+func LogTryIncRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs))
+ }
+}
+
+// LogDecRef logs a reference decrement.
+func LogDecRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("DecRef to %d", refs))
+ }
+}
+
+// logEvent logs a message for the given reference-counted object.
+//
+// obj.LogRefs() should be checked before calling logEvent, in order to avoid
+// calling any text processing needed to evaluate msg.
+func logEvent(obj CheckedObject, msg string) {
+ log.Infof("[%s %p] %s:", obj.RefType(), obj, msg)
+ log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
+}
+
+// DoLeakCheck iterates through the live object map and logs a message for each
+// object. It is called once no reference-counted objects should be reachable
+// anymore, at which point anything left in the map is considered a leak.
+func DoLeakCheck() {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ leaked := len(liveObjects)
+ if leaked > 0 {
+ log.Warningf("Leak checking detected %d leaked objects:", leaked)
+ for obj := range liveObjects {
+ log.Warningf(obj.LeakMessage())
+ }
+ }
+ }
+}
diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index d9b552896..8f50b4ee6 100644
--- a/pkg/refs_vfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -21,20 +21,24 @@ package refs_template
import (
"fmt"
- "runtime"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/log"
- refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
)
+// enableLogging indicates whether reference-related events should be logged (with
+// stack traces). This is false by default and should only be set to true for
+// debugging purposes, as it can generate an extremely large amount of output
+// and drastically degrade performance.
+const enableLogging = false
+
// T is the type of the reference counted object. It is only used to customize
// debug output when leak checking.
type T interface{}
-// ownerType is used to customize logging. Note that we use a pointer to T so
-// that we do not copy the entire object when passed as a format parameter.
-var ownerType *T
+// obj is used to customize logging. Note that we use a pointer to T so that
+// we do not copy the entire object when passed as a format parameter.
+var obj *T
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
@@ -42,11 +46,6 @@ var ownerType *T
// Note that the number of references is actually refCount + 1 so that a default
// zero-value Refs object contains one reference.
//
-// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in
-// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount.
-// This will allow us to add stack trace information to the leak messages
-// without growing the size of Refs.
-//
// +stateify savable
type Refs struct {
// refCount is composed of two fields:
@@ -59,24 +58,24 @@ type Refs struct {
refCount int64
}
-func (r *Refs) finalize() {
- var note string
- switch refs_vfs1.GetLeakMode() {
- case refs_vfs1.NoLeakChecking:
- return
- case refs_vfs1.UninitializedLeakChecking:
- note = "(Leak checker uninitialized): "
- }
- if n := r.ReadRefs(); n != 0 {
- log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n)
- }
+// RefType implements refsvfs2.CheckedObject.RefType.
+func (r *Refs) RefType() string {
+ return fmt.Sprintf("%T", obj)[1:]
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (r *Refs) LeakMessage() string {
+ return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs())
}
-// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+func (r *Refs) LogRefs() bool {
+ return enableLogging
+}
+
+// EnableLeakCheck enables reference leak checking on r.
func (r *Refs) EnableLeakCheck() {
- if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
- runtime.SetFinalizer(r, (*Refs).finalize)
- }
+ refsvfs2.Register(r)
}
// ReadRefs returns the current number of references. The returned count is
@@ -90,8 +89,10 @@ func (r *Refs) ReadRefs() int64 {
//
//go:nosplit
func (r *Refs) IncRef() {
- if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
- panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType))
+ v := atomic.AddInt64(&r.refCount, 1)
+ refsvfs2.LogIncRef(r, v+1)
+ if v <= 0 {
+ panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
}
@@ -104,15 +105,15 @@ func (r *Refs) IncRef() {
//go:nosplit
func (r *Refs) TryIncRef() bool {
const speculativeRef = 1 << 32
- v := atomic.AddInt64(&r.refCount, speculativeRef)
- if int32(v) < 0 {
+ if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 {
// This object has already been freed.
atomic.AddInt64(&r.refCount, -speculativeRef)
return false
}
// Turn into a real reference.
- atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ v := atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ refsvfs2.LogTryIncRef(r, v+1)
return true
}
@@ -129,14 +130,23 @@ func (r *Refs) TryIncRef() bool {
//
//go:nosplit
func (r *Refs) DecRef(destroy func()) {
- switch v := atomic.AddInt64(&r.refCount, -1); {
+ v := atomic.AddInt64(&r.refCount, -1)
+ refsvfs2.LogDecRef(r, v+1)
+ switch {
case v < -1:
- panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType))
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
case v == -1:
+ refsvfs2.Unregister(r)
// Call the destructor.
if destroy != nil {
destroy()
}
}
}
+
+func (r *Refs) afterLoad() {
+ if r.ReadRefs() > 0 {
+ r.EnableLeakCheck()
+ }
+}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index 41feeffe3..d800f2c85 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
s.Kernel.Kill(kernel.ExitStatus{})
},
}
- return saveOpts.Save(s.Kernel, s.Watchdog)
+ return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog)
}
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 655ea549b..ff5d49fbd 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -39,6 +39,8 @@ const (
)
// tunDevice implements vfs.Device for /dev/net/tun.
+//
+// +stateify savable
type tunDevice struct{}
// Open implements vfs.Device.Open.
@@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt
}
// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+//
+// +stateify savable
type tunFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 1390a9a7f..4468f5dd2 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() {
f.mappings = make(map[uint64]mapping)
}
+// IsInited returns true if f.Init() has been called. This is used when
+// restoring a checkpoint that contains a HostFileMapper that may or may not
+// have been initialized.
+func (f *HostFileMapper) IsInited() bool {
+ return f.refs != nil
+}
+
// NewHostFileMapper returns an initialized HostFileMapper allocated on the
// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 3c66dc3c2..6b3627813 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
@@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode,
}
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Wrap the fileOps with our Fifo.
iops := &fifo{
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index e555672ad..52061175f 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
}
// GetFile implements fs.InodeOperations.GetFile.
-func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
- return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil
+ return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil
}
// +stateify savable
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 22d658acf..450044c9c 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo
"gid_map": newGIDMap(t, msrc),
"io": newIO(t, msrc, isThreadGroup),
"maps": newMaps(t, msrc),
+ "mem": newMem(t, msrc),
"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
"net": newNetDir(t, msrc),
@@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
+// memData implements fs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memData struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// memDataFile implements fs.FileOperations for /proc/[pid]/mem.
+//
+// +stateify savable
+type memDataFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ inode := &memData{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, inode, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (m *memData) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, m.t, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(m.t); err != nil {
+ return nil, err
+ }
+ // Enable random access reads
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ mm, err := getTaskMM(m.t)
+ if err != nil {
+ return 0, nil
+ }
+ defer mm.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
// mapsData implements seqfile.SeqSource for /proc/[pid]/maps.
//
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index fc0498f17..d6c65301c 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
@@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
// Write to that memory as usual.
seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 998b697ca..cf4ed5de0 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -336,7 +336,7 @@ type Fifo struct {
// NewFifo creates a new named pipe.
func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Build pipe InodeOperations.
iops := pipe.NewInodeOperations(ctx, perms, p)
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 84baaac66..6af3c3781 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "root_inode_refs.go",
package = "devpts",
prefix = "rootInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "rootInode",
},
@@ -33,6 +33,7 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index d5c5aaa8c..346cca558 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir
}
fstype.initOnce.Do(func() {
- fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds)
if err != nil {
fstype.initErr = err
return
@@ -93,7 +93,7 @@ type filesystem struct {
// newFilesystem creates a new devpts filesystem with root directory and ptmx
// master inode. It returns the filesystem and root Dentry.
-func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -108,19 +108,19 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root := &rootInode{
replicas: make(map[uint32]*replicaInode),
}
- root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
+ root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
root.EnableLeakCheck()
var rootD kernfs.Dentry
- rootD.Init(&fs.Filesystem, root)
+ rootD.InitRoot(&fs.Filesystem, root)
// Construct the pts master inode and dentry. Linux always uses inode
// id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
master := &masterInode{
root: root,
}
- master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
+ master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
// Add the master as a child of the root.
links := root.OrderedChildren.Populate(map[string]kernfs.Inode{
@@ -170,7 +170,7 @@ type rootInode struct {
var _ kernfs.Inode = (*rootInode)(nil)
// allocateTerminal creates a new Terminal and installs a pts node for it.
-func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) {
i.mu.Lock()
defer i.mu.Unlock()
if i.nextIdx == math.MaxUint32 {
@@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
}
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
- replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
i.replicas[idx] = replica
return t, nil
@@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
i.mu.Lock()
defer i.mu.Unlock()
+ i.InodeAttrs.TouchAtime(ctx, mnt)
ids := make([]int, 0, len(i.replicas))
for id := range i.replicas {
ids = append(ids, int(id))
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index e6b0e81cf..ae95fdd08 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -100,10 +100,10 @@ type lineDiscipline struct {
column int
// masterWaiter is used to wait on the master end of the TTY.
- masterWaiter waiter.Queue `state:"zerovalue"`
+ masterWaiter waiter.Queue
// replicaWaiter is used to wait on the replica end of the TTY.
- replicaWaiter waiter.Queue `state:"zerovalue"`
+ replicaWaiter waiter.Queue
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index fda30fb93..e91fa26a4 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil)
// Open implements kernfs.Inode.Open.
func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- t, err := mi.root.allocateTerminal(rp.Credentials())
+ t, err := mi.root.allocateTerminal(ctx, rp.Credentials())
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD
index 01bbee5ad..e49a04c1b 100644
--- a/pkg/sentry/fsimpl/devtmpfs/BUILD
+++ b/pkg/sentry/fsimpl/devtmpfs/BUILD
@@ -4,7 +4,10 @@ licenses(["notice"])
go_library(
name = "devtmpfs",
- srcs = ["devtmpfs.go"],
+ srcs = [
+ "devtmpfs.go",
+ "save_restore.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
new file mode 100644
index 000000000..28832d850
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devtmpfs
+
+// afterLoad is invoked by stateify.
+func (fst *FilesystemType) afterLoad() {
+ if fst.fs != nil {
+ // Ensure that we don't create another filesystem.
+ fst.initOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 1c27ad700..5b29f2358 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -43,7 +43,7 @@ type EventFileDescription struct {
// queue is used to notify interested parties when the event object
// becomes readable or writable.
- queue waiter.Queue `state:"zerovalue"`
+ queue waiter.Queue
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 045d7ab08..2158b1bbc 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "inode_refs.go",
package = "fuse",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -49,6 +49,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/kernfs",
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 5986133e9..95c475a65 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
readPayload.MarshalUnsafe(outBuf[outHdrLen:])
outIOseq := usermem.BytesIOSequence(outBuf)
- n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
if err != nil {
t.Fatalf("Write failed :%v", err)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index e39df21c6..6de416da0 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// root is the fusefs root directory.
- root := fs.newRootInode(creds, fsopts.rootMode)
+ root := fs.newRoot(ctx, creds, fsopts.rootMode)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
@@ -284,21 +284,21 @@ type inode struct {
link string
}
-func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
i := &inode{fs: fs, nodeID: 1}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
+ i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
var d kernfs.Dentry
- d.Init(&fs.Filesystem, i)
+ d.InitRoot(&fs.Filesystem, i)
return &d
}
-func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
+func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
i := &inode{fs: fs, nodeID: nodeID}
creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
- i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
+ i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
@@ -424,7 +424,7 @@ func (i *inode) Keep() bool {
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
@@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) {
return nil, syserror.EIO
}
- child := i.fs.newInode(out.NodeID, out.Attr)
+ child := i.fs.newInode(ctx, out.NodeID, out.Attr)
return child, nil
}
@@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return linux.FUSEAttr{}, err
@@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 625d1547f..2d396e84c 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u
// May need to update the signature.
i := fd.inode()
- // TODO(gvisor.dev/issue/1193): Invalidate or update atime.
+ i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount())
// Reached EOF.
if sizeRead < size {
@@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
Flags: fd.statusFlags(),
}
+ inode := fd.inode()
var written uint32
// This loop is intended for fragmented write where the bytes to write is
@@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in)
+ req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
if err != nil {
return 0, err
}
@@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
break
}
}
+ inode.InodeAttrs.TouchCMtime(ctx)
return written, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index ad0afc41b..4c3e9acf8 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"host_named_pipe.go",
"p9file.go",
"regular_file.go",
+ "save_restore.go",
"socket.go",
"special_file.go",
"symlink.go",
@@ -53,6 +54,7 @@ go_library(
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
@@ -70,6 +72,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 18c884b59..ce1b2a390 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -16,16 +16,17 @@ package gofer
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
child := &dentry{
refs: 1, // held by d
fs: d.fs,
- ino: d.fs.nextSyntheticIno(),
+ ino: d.fs.nextIno(),
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
@@ -100,6 +101,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
hostFD: -1,
nlink: uint32(2),
}
+ refsvfs2.Register(child)
switch opts.mode.FileType() {
case linux.S_IFDIR:
// Nothing else needs to be done.
@@ -235,7 +237,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
}
dirent := vfs.Dirent{
Name: p9d.Name,
- Ino: uint64(inoFromPath(p9d.QID.Path)),
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
NextOff: int64(len(dirents) + 1),
}
// p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 94d96261b..bbb01148b 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -30,12 +30,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
- // Snapshot current syncable dentries and special files.
+ // Snapshot current syncable dentries and special file FDs.
fs.syncMu.Lock()
ds := make([]*dentry, 0, len(fs.syncableDentries))
for d := range fs.syncableDentries {
@@ -53,22 +52,28 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
- // Sync regular files.
+ // Sync syncable dentries.
for _, d := range ds {
- err := d.syncCachedFile(ctx)
+ err := d.syncCachedFile(ctx, true /* forFilesystemSync */)
d.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- err := sffd.Sync(ctx)
+ err := sffd.sync(ctx, true /* forFilesystemSync */)
sffd.vfsfd.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
@@ -229,7 +234,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
return nil, err
}
if child != nil {
- if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ if !file.isNil() && qid.Path == child.qidPath {
// The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
child.updateFromP9AttrsLocked(attrMask, &attr)
@@ -256,7 +261,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// treat their invalidation as deletion.
child.setDeleted()
parent.syntheticChildren--
- child.decRefLocked()
+ child.decRefNoCaching()
parent.dirents = nil
}
*ds = appendDentry(*ds, child)
@@ -366,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if len(name) > maxFilenameLen {
return syserror.ENAMETOOLONG
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.isDeleted() {
return syserror.ENOENT
}
@@ -383,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if createInSyntheticDir == nil {
return syserror.EPERM
}
@@ -402,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil && child.isSynthetic() {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// The existence of a non-synthetic dentry at name would be inconclusive
// because the file it represents may have been deleted from the remote
// filesystem, so we would need to make an RPC to revalidate the dentry.
@@ -422,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// No cached dentry exists; however, there might still be an existing file
// at name. As above, we attempt the file creation RPC anyway.
if err := createInRemoteDir(parent, name, &ds); err != nil {
@@ -625,7 +636,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
child.setDeleted()
if child.isSynthetic() {
parent.syntheticChildren--
- child.decRefLocked()
+ child.decRefNoCaching()
}
ds = appendDentry(ds, child)
}
@@ -836,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
mode: opts.Mode,
kuid: creds.EffectiveKUID,
kgid: creds.EffectiveKGID,
- pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
return nil
}
@@ -1355,7 +1366,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
replaced.setDeleted()
if replaced.isSynthetic() {
newParent.syntheticChildren--
- replaced.decRefLocked()
+ replaced.decRefNoCaching()
}
ds = appendDentry(ds, replaced)
}
@@ -1364,7 +1375,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// with reference counts and queue oldParent for checkCachingLocked if the
// parent isn't actually changing.
if oldParent != newParent {
- oldParent.decRefLocked()
+ oldParent.decRefNoCaching()
ds = appendDentry(ds, oldParent)
newParent.IncRef()
if renamed.isSynthetic() {
@@ -1512,7 +1523,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
d.IncRef()
return &endpoint{
dentry: d,
- file: d.file.file,
path: opts.Addr,
}, nil
}
@@ -1591,7 +1601,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
-
-func (fs *filesystem) nextSyntheticIno() inodeNumber {
- return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
-}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index f1dad1b08..6f82ce61b 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -26,6 +26,9 @@
// *** "memmap.Mappable locks taken by Translate" below this point
// dentry.handleMu
// dentry.dataMu
+// filesystem.inoMu
+// specialFileFD.mu
+// specialFileFD.bufMu
//
// Locking dentry.dirMu in multiple dentries requires that either ancestor
// dentries are locked before descendant dentries, or that filesystem.renameMu
@@ -36,7 +39,6 @@ import (
"fmt"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
@@ -44,6 +46,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -53,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/usermem"
@@ -81,7 +86,7 @@ type filesystem struct {
iopts InternalFilesystemOptions
// client is the client used by this filesystem. client is immutable.
- client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ client *p9.Client `state:"nosave"`
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -89,6 +94,9 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
+ // root is the root dentry. root is immutable.
+ root *dentry
+
// renameMu serves two purposes:
//
// - It synchronizes path resolution with renaming initiated by this
@@ -103,39 +111,35 @@ type filesystem struct {
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
- // cachedDentriesLen is the number of dentries in cachedDentries. These
- // fields are protected by renameMu.
+ // cachedDentriesLen is the number of dentries in cachedDentries. These fields
+ // are protected by renameMu.
cachedDentries dentryList
cachedDentriesLen uint64
- // syncableDentries contains all dentries in this filesystem for which
- // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
- // These fields are protected by syncMu.
+ // syncableDentries contains all non-synthetic dentries. specialFileFDs
+ // contains all open specialFileFDs. These fields are protected by syncMu.
syncMu sync.Mutex `state:"nosave"`
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
- // syntheticSeq stores a counter to used to generate unique inodeNumber for
- // synthetic dentries.
- syntheticSeq uint64
-}
+ // inoByQIDPath maps previously-observed QID.Paths to inode numbers
+ // assigned to those paths. inoByQIDPath is not preserved across
+ // checkpoint/restore because QIDs may be reused between different gofer
+ // processes, so QIDs may be repeated for different files across
+ // checkpoint/restore. inoByQIDPath is protected by inoMu.
+ inoMu sync.Mutex `state:"nosave"`
+ inoByQIDPath map[uint64]uint64 `state:"nosave"`
-// inodeNumber represents inode number reported in Dirent.Ino. For regular
-// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
-// have have their inodeNumber generated sequentially, with the MSB reserved to
-// prevent conflicts with regular dentries.
-//
-// +stateify savable
-type inodeNumber uint64
+ // lastIno is the last inode number assigned to a file. lastIno is accessed
+ // using atomic memory operations.
+ lastIno uint64
-// Reserve MSB for synthetic mounts.
-const syntheticInoMask = uint64(1) << 63
+ // savedDentryRW records open read/write handles during save/restore.
+ savedDentryRW map[*dentry]savedDentryRW
-func inoFromPath(path uint64) inodeNumber {
- if path&syntheticInoMask != 0 {
- log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
- }
- return inodeNumber(path &^ syntheticInoMask)
+ // released is nonzero once filesystem.Release has been called. It is accessed
+ // with atomic memory operations.
+ released int32
}
// +stateify savable
@@ -149,8 +153,7 @@ type filesystemOptions struct {
msize uint32
version string
- // maxCachedDentries is the maximum number of dentries with 0 references
- // retained by the client.
+ // maxCachedDentries is the maximum size of filesystem.cachedDentries.
maxCachedDentries uint64
// If forcePageCache is true, host FDs may not be used for application
@@ -247,6 +250,10 @@ const (
//
// +stateify savable
type InternalFilesystemOptions struct {
+ // If UniqueID is non-empty, it is an opaque string used to reassociate the
+ // filesystem with a new server FD during restoration from checkpoint.
+ UniqueID string
+
// If LeakConnection is true, do not close the connection to the server
// when the Filesystem is released. This is necessary for deployments in
// which servers can handle only a single client and report failure if that
@@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mopts := vfs.GenericParseMountOptions(opts.Data)
var fsopts filesystemOptions
- // Check that the transport is "fd".
- trans, ok := mopts["trans"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "trans")
- if trans != "fd" {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
- return nil, nil, syserror.EINVAL
- }
-
- // Check that read and write FDs are provided and identical.
- rfdstr, ok := mopts["rfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "rfdno")
- rfd, err := strconv.Atoi(rfdstr)
+ fd, err := getFDFromMountOptionsMap(ctx, mopts)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
- return nil, nil, syserror.EINVAL
- }
- wfdstr, ok := mopts["wfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "wfdno")
- wfd, err := strconv.Atoi(wfdstr)
- if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
- return nil, nil, syserror.EINVAL
- }
- if rfd != wfd {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
- return nil, nil, syserror.EINVAL
+ return nil, nil, err
}
- fsopts.fd = rfd
+ fsopts.fd = fd
// Get the attach name.
fsopts.aname = "/"
@@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// If !ok, iopts being the zero value is correct.
- // Establish a connection with the server.
- conn, err := unet.NewSocket(fsopts.fd)
+ // Construct the filesystem object.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ iopts: iopts,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ devMinor: devMinor,
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
- // Perform version negotiation with the server.
- ctx.UninterruptibleSleepStart(false)
- client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- conn.Close()
+ // Connect to the server.
+ if err := fs.dial(ctx); err != nil {
return nil, nil, err
}
- // Ownership of conn has been transferred to client.
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fsopts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
- // Construct the filesystem object.
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- attachFile.close(ctx)
- client.Close()
- return nil, nil, err
- }
- fs := &filesystem{
- mfp: mfp,
- opts: fsopts,
- iopts: iopts,
- client: client,
- clock: ktime.RealtimeClockFromContext(ctx),
- devMinor: devMinor,
- syncableDentries: make(map[*dentry]struct{}),
- specialFileFDs: make(map[*specialFileFD]struct{}),
- }
- fs.vfsfs.Init(vfsObj, &fstype, fs)
-
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
@@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
// Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is deliberately leaked to prevent the root from
- // being "cached" and subsequently evicted. Its resources will still be
- // cleaned up by fs.Release().
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
root.refs = 2
+ fs.root = root
return &fs.vfsfs, &root.vfsd, nil
}
+func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok || trans != "fd" {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr)
+ return -1, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr)
+ return -1, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return -1, syserror.EINVAL
+ }
+ return rfd, nil
+}
+
+// Preconditions: fs.client == nil.
+func (fs *filesystem) dial(ctx context.Context) error {
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return err
+ }
+ // Ownership of conn has been transferred to client.
+
+ fs.client = client
+ return nil
+}
+
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
- mf := fs.mfp.MemoryFile()
+ atomic.StoreInt32(&fs.released, 1)
+ mf := fs.mfp.MemoryFile()
fs.syncMu.Lock()
for d := range fs.syncableDentries {
d.handleMu.Lock()
d.dataMu.Lock()
if h := d.writeHandleLocked(); h.isOpen() {
// Write dirty cached data to the remote file.
- if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil {
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
}
// TODO(jamieliu): Do we need to flushf/fsync d?
@@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) {
// fs.
fs.syncMu.Unlock()
+ // If leak checking is enabled, release all outstanding references in the
+ // filesystem. We deliberately avoid doing this outside of leak checking; we
+ // have released all external resources above rather than relying on dentry
+ // destructors.
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ fs.renameMu.Lock()
+ fs.root.releaseSyntheticRecursiveLocked(ctx)
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // An extra reference was held by the filesystem on the root to prevent it from
+ // being cached/evicted.
+ fs.root.DecRef(ctx)
+ }
+
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
fs.client.Close()
@@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
}
+// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements
+// the reference count on every synthetic dentry. Synthetic dentries have one
+// reference for existence that should be dropped during filesystem.Release.
+//
+// Precondition: d.fs.renameMu is locked.
+func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
+ if d.isSynthetic() {
+ d.decRefNoCaching()
+ d.checkCachingLocked(ctx)
+ }
+ if d.isDir() {
+ var children []*dentry
+ d.dirMu.Lock()
+ for _, child := range d.children {
+ children = append(children, child)
+ }
+ d.dirMu.Unlock()
+ for _, child := range children {
+ if child != nil {
+ child.releaseSyntheticRecursiveLocked(ctx)
+ }
+ }
+ }
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -574,12 +635,15 @@ type dentry struct {
// filesystem.renameMu.
name string
+ // qidPath is the p9.QID.Path for this file. qidPath is immutable.
+ qidPath uint64
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
// that does not exist on the remote filesystem. As of this writing, the
// only files that can be synthetic are sockets, pipes, and directories.
- file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ file p9file `state:"nosave"`
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
@@ -623,12 +687,12 @@ type dentry struct {
// To mutate:
// - Lock metadataMu and use atomic operations to update because we might
// have atomic readers that don't hold the lock.
- metadataMu sync.Mutex `state:"nosave"`
- ino inodeNumber // immutable
- mode uint32 // type is immutable, perms are mutable
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- blockSize uint32 // 0 if unknown
+ metadataMu sync.Mutex `state:"nosave"`
+ ino uint64 // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
// Timestamps, all nsecs from the Unix epoch.
atime int64
mtime int64
@@ -679,9 +743,9 @@ type dentry struct {
// (isNil() == false), it may be mutated with handleMu locked, but cannot
// be closed until the dentry is destroyed.
handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- hostFD int32
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ hostFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d := &dentry{
fs: fs,
+ qidPath: qid.Path,
file: file,
- ino: inoFromPath(qid.Path),
+ ino: fs.inoFromQIDPath(qid.Path),
mode: uint32(attr.Mode),
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
@@ -795,13 +860,28 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d.nlink = uint32(attr.NLink)
}
d.vfsd.Init(d)
-
+ refsvfs2.Register(d)
fs.syncMu.Lock()
fs.syncableDentries[d] = struct{}{}
fs.syncMu.Unlock()
return d, nil
}
+func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+ if ino, ok := fs.inoByQIDPath[qidPath]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByQIDPath[qidPath] = ino
+ return ino
+}
+
+func (fs *filesystem) nextIno() uint64 {
+ return atomic.AddUint64(&fs.lastIno, 1)
+}
+
func (d *dentry) isSynthetic() bool {
return d.file.isNil()
}
@@ -853,7 +933,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
-// Preconditions: !d.isSynthetic()
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
// Use d.readFile or d.writeFile, which represent 9P fids that have been
// opened, in preference to d.file, which represents a 9P fid that has not.
@@ -916,10 +996,10 @@ func (d *dentry) statTo(stat *linux.Statx) {
// This is consistent with regularFileFD.Seek(), which treats regular files
// as having no holes.
stat.Blocks = (stat.Size + 511) / 512
- stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
- stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
- stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
- stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime))
+ stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime))
stat.DevMajor = linux.UNNAMED_MAJOR
stat.DevMinor = d.fs.devMinor
}
@@ -967,10 +1047,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// Use client clocks for timestamps.
now = d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW {
- stat.Atime = statxTimestampFromDentry(now)
+ stat.Atime = linux.NsecToStatxTimestamp(now)
}
if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW {
- stat.Mtime = statxTimestampFromDentry(now)
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
}
}
@@ -1029,11 +1109,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
if stat.Mask&linux.STATX_ATIME != 0 {
- atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
if stat.Mask&linux.STATX_MTIME != 0 {
- atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
@@ -1139,17 +1219,19 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 {
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkCachingLocked().
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -1157,22 +1239,41 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ if d.decRefNoCaching() == 0 {
d.fs.renameMu.Lock()
d.checkCachingLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
- panic("gofer.dentry.DecRef() called without holding a reference")
}
}
-// decRefLocked decrements d's reference count without calling
+// decRefNoCaching decrements d's reference count without calling
// d.checkCachingLocked, even if d's reference count reaches 0; callers are
// responsible for ensuring that d.checkCachingLocked will be called later.
-func (d *dentry) decRefLocked() {
- if refs := atomic.AddInt64(&d.refs, -1); refs < 0 {
- panic("gofer.dentry.decRefLocked() called without holding a reference")
+func (d *dentry) decRefNoCaching() int64 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r < 0 {
+ panic("gofer.dentry.decRefNoCaching() called without holding a reference")
}
+ return r
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "gofer.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -1223,6 +1324,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
// resolution, which requires renameMu, so if d.refs is zero then it will
// remain zero while we hold renameMu for writing.)
refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
if refs > 0 {
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1231,10 +1336,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
}
return
}
- if refs == -1 {
- // Dentry has already been destroyed.
- return
- }
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
@@ -1257,6 +1358,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
if d.watches.Size() > 0 {
return
}
+
+ if atomic.LoadInt32(&d.fs.released) != 0 {
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ delete(d.parent.children, d.name)
+ d.parent.dirMu.Unlock()
+ }
+ d.destroyLocked(ctx)
+ }
+
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1269,33 +1380,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
d.fs.cachedDentriesLen++
d.cached = true
if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
- victim := d.fs.cachedDentries.Back()
- d.fs.cachedDentries.Remove(victim)
- d.fs.cachedDentriesLen--
- victim.cached = false
- // victim.refs may have become non-zero from an earlier path resolution
- // since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 {
- if victim.parent != nil {
- victim.parent.dirMu.Lock()
- if !victim.vfsd.IsDead() {
- // Note that victim can't be a mount point (in any mount
- // namespace), since VFS holds references on mount points.
- d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
- delete(victim.parent.children, victim.name)
- // We're only deleting the dentry, not the file it
- // represents, so we don't need to update
- // victimParent.dirents etc.
- }
- victim.parent.dirMu.Unlock()
- }
- victim.destroyLocked(ctx)
- }
+ d.fs.evictCachedDentryLocked(ctx)
// Whether or not victim was destroyed, we brought fs.cachedDentriesLen
// back down to fs.opts.maxCachedDentries, so we don't loop.
}
}
+// Precondition: fs.renameMu must be locked for writing; it may be temporarily
+// unlocked.
+func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+}
+
+// Preconditions:
+// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// * fs.cachedDentriesLen != 0.
+func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ victim := fs.cachedDentries.Back()
+ fs.cachedDentries.Remove(victim)
+ fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+}
+
// destroyLocked destroys the dentry.
//
// Preconditions:
@@ -1373,13 +1499,10 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// Drop the reference held by d on its parent without recursively locking
// d.fs.renameMu.
- if d.parent != nil {
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkCachingLocked(ctx)
- } else if refs < 0 {
- panic("gofer.dentry.DecRef() called without holding a reference")
- }
+ if d.parent != nil && d.parent.decRefNoCaching() == 0 {
+ d.parent.checkCachingLocked(ctx)
}
+ refsvfs2.Unregister(d)
}
func (d *dentry) isDeleted() bool {
@@ -1623,6 +1746,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
return nil
}
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ if h.isOpen() {
+ // Write back dirty pages to the remote file.
+ d.dataMu.Lock()
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (d is a regular file and was opened for writing).
+ if d.isRegularFile() && h.isOpen() {
+ return err
+ }
+ ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err)
+ }
+ return nil
+}
+
// incLinks increments link count.
func (d *dentry) incLinks() {
if atomic.LoadUint32(&d.nlink) == 0 {
@@ -1650,7 +1800,7 @@ type fileDescription struct {
vfs.FileDescriptionDefaultImpl
vfs.LockFD
- lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ lockLogging sync.Once `state:"nosave"`
}
func (fd *fileDescription) filesystem() *filesystem {
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index bfe75dfe4..76f08e252 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -26,12 +26,13 @@ import (
func TestDestroyIdempotent(t *testing.T) {
ctx := contexttest.Context(t)
fs := filesystem{
- mfp: pgalloc.MemoryFileProviderFromContext(ctx),
- syncableDentries: make(map[*dentry]struct{}),
+ mfp: pgalloc.MemoryFileProviderFromContext(ctx),
opts: filesystemOptions{
// Test relies on no dentry being held in the cache.
maxCachedDentries: 0,
},
+ syncableDentries: make(map[*dentry]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
index 7294de7d6..c7bf10007 100644
--- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
if ok {
return nil
}
- if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
- return err
+ if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil {
+ // Another application thread may have opened this pipe for
+ // writing, succeeded because we previously opened the pipe for
+ // reading, and subsequently interrupted us for checkpointing (e.g.
+ // this occurs in mknod tests under cooperative save/restore). In
+ // this case, our open has to succeed for the checkpoint to include
+ // a readable FD for the pipe, which is in turn necessary to
+ // restore the other thread's writable FD for the same pipe
+ // (otherwise it will get ENXIO). So we have to check
+ // nonblockingPipeHasWriter() once last time.
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ return sleepErr
}
}
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f8b19bae7..dc8a890cb 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -18,7 +18,6 @@ import (
"fmt"
"io"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx)
-}
-
-func (d *dentry) syncCachedFile(ctx context.Context) error {
- d.handleMu.RLock()
- defer d.handleMu.RUnlock()
-
- if h := d.writeHandleLocked(); h.isOpen() {
- d.dataMu.Lock()
- // Write dirty cached data to the remote file.
- err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
- d.dataMu.Unlock()
- if err != nil {
- return err
- }
- }
- return d.syncRemoteFileLocked(ctx)
+ return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
@@ -913,7 +897,7 @@ type dentryPlatformFile struct {
hostFileMapper fsutil.HostFileMapper
// hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
- hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ hostFileMapperInitOnce sync.Once `state:"nosave"`
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
new file mode 100644
index 000000000..17849dcc0
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -0,0 +1,329 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type saveRestoreContextID int
+
+const (
+ // CtxRestoreServerFDMap is a Context.Value key for a map[string]int
+ // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID)
+ // to host FDs.
+ CtxRestoreServerFDMap saveRestoreContextID = iota
+)
+
+// +stateify savable
+type savedDentryRW struct {
+ read bool
+ write bool
+}
+
+// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave.
+func (fs *filesystem) PrepareSave(ctx context.Context) error {
+ if len(fs.iopts.UniqueID) == 0 {
+ return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved")
+ }
+
+ // Purge cached dentries, which may not be reopenable after restore due to
+ // permission changes.
+ fs.renameMu.Lock()
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // Buffer pipe data so that it's available for reading after restore. (This
+ // is a legacy VFS1 feature.)
+ fs.syncMu.Lock()
+ for sffd := range fs.specialFileFDs {
+ if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() {
+ if err := sffd.savePipeData(ctx); err != nil {
+ fs.syncMu.Unlock()
+ return err
+ }
+ }
+ }
+ fs.syncMu.Unlock()
+
+ // Flush local state to the remote filesystem.
+ if err := fs.Sync(ctx); err != nil {
+ return err
+ }
+
+ fs.savedDentryRW = make(map[*dentry]savedDentryRW)
+ return fs.root.prepareSaveRecursive(ctx)
+}
+
+// Preconditions:
+// * fd represents a pipe.
+// * fd is readable.
+func (fd *specialFileFD) savePipeData(ctx context.Context) error {
+ fd.bufMu.Lock()
+ defer fd.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0))
+ if n != 0 {
+ fd.buf = append(fd.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syserror.EAGAIN {
+ break
+ }
+ return err
+ }
+ }
+ if len(fd.buf) != 0 {
+ atomic.StoreUint32(&fd.haveBuf, 1)
+ }
+ return nil
+}
+
+func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
+ if d.isRegularFile() && !d.cachedMetadataAuthoritative() {
+ // Get updated metadata for d in case we need to perform metadata
+ // validation during restore.
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
+ }
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ for _, child := range d.children {
+ if child != nil {
+ if err := child.prepareSaveRecursive(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// beforeSave is invoked by stateify.
+func (d *dentry) beforeSave() {
+ if d.vfsd.IsDead() {
+ panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d)))
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentry) afterLoad() {
+ d.hostFD = -1
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentryPlatformFile) afterLoad() {
+ if d.hostFileMapper.IsInited() {
+ // Ensure that we don't call d.hostFileMapper.Init() again.
+ d.hostFileMapperInitOnce.Do(func() {})
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (fd *specialFileFD) afterLoad() {
+ fd.handle.fd = -1
+}
+
+// CompleteRestore implements
+// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore.
+func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error {
+ fdmapv := ctx.Value(CtxRestoreServerFDMap)
+ if fdmapv == nil {
+ return fmt.Errorf("no server FD map available")
+ }
+ fdmap := fdmapv.(map[string]int)
+ fd, ok := fdmap[fs.iopts.UniqueID]
+ if !ok {
+ return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
+ }
+ fs.opts.fd = fd
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+ fs.inoByQIDPath = make(map[uint64]uint64)
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
+
+ // Restore remaining dentries.
+ if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil {
+ return err
+ }
+
+ // Re-open handles for specialFileFDs. Unlike the initial open
+ // (dentry.openSpecialFile()), pipes are always opened without blocking;
+ // non-readable pipe FDs are opened last to ensure that they don't get
+ // ENXIO if another specialFileFD represents the read end of the same pipe.
+ // This is consistent with VFS1.
+ haveWriteOnlyPipes := false
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ haveWriteOnlyPipes = true
+ continue
+ }
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ if haveWriteOnlyPipes {
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ // Discard state only required during restore.
+ fs.savedDentryRW = nil
+
+ return nil
+}
+
+func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error {
+ d.file = file
+
+ // Gofers do not preserve QID across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking QIDs.
+ //
+ // - We need to associate the new QID.Path with the existing d.ino.
+ d.qidPath = qid.Path
+ d.fs.inoMu.Lock()
+ d.fs.inoByQIDPath[qid.Path] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if !attrMask.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ }
+ if d.size != attr.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if !attrMask.MTime {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ }
+ if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromP9AttrsLocked(attrMask, attr)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Preconditions: d is not synthetic.
+func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ for _, child := range d.children {
+ if child == nil {
+ continue
+ }
+ if _, ok := d.fs.syncableDentries[child]; !ok {
+ // child is synthetic.
+ continue
+ }
+ if err := child.restoreRecursive(ctx, opts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Preconditions: d is not synthetic (but note that since this function
+// restores d.file, d.file.isNil() is always true at this point, so this can
+// only be detected by checking filesystem.syncableDentries). d.parent has been
+// restored.
+func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
+ return d.restoreDescendantsRecursive(ctx, opts)
+}
+
+func (fd *specialFileFD) completeRestore(ctx context.Context) error {
+ d := fd.dentry()
+ h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ if err != nil {
+ return err
+ }
+ fd.handle = h
+
+ ftype := d.fileType()
+ fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0
+ if fd.haveQueue {
+ if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index 326b940a7..a21199eac 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -42,9 +42,6 @@ type endpoint struct {
// dentry is the filesystem dentry which produced this endpoint.
dentry *dentry
- // file is the p9 file that contains a single unopened fid.
- file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
-
// path is the sentry path where this endpoint is bound.
path string
}
@@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
}
func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.file.Connect(flags)
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
return nil, syserr.ErrConnectionRefused
}
@@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla
c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
return nil, serr
}
return c, nil
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 71581736c..625400c0b 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -15,7 +15,6 @@
package gofer
import (
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -40,7 +40,7 @@ type specialFileFD struct {
fileDescription
// handle is used for file I/O. handle is immutable.
- handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ handle handle `state:"nosave"`
// isRegularFile is true if this FD represents a regular file which is only
// possible when filesystemOptions.regularFilesUseSpecialFileFD is in
@@ -54,12 +54,20 @@ type specialFileFD struct {
// haveQueue is true if this file description represents a file for which
// queue may send I/O readiness events. haveQueue is immutable.
- haveQueue bool
+ haveQueue bool `state:"nosave"`
queue waiter.Queue
// If seekable is true, off is the file offset. off is protected by mu.
mu sync.Mutex `state:"nosave"`
off int64
+
+ // If haveBuf is non-zero, this FD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to specialFileFD.savePipeData().
+ // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic
+ // memory operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
}
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
@@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
}
return nil, err
}
+ d.fs.syncMu.Lock()
+ d.fs.specialFileFDs[fd] = struct{}{}
+ d.fs.syncMu.Unlock()
return fd, nil
}
@@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
return 0, syserror.EOPNOTSUPP
}
- // Going through dst.CopyOutFrom() holds MM locks around file operations of
- // unknown duration. For regularFileFD, doing so is necessary to support
- // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
- // hold here since specialFileFD doesn't client-cache data. Just buffer the
- // read instead.
if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
+
+ bufN := int64(0)
+ if atomic.LoadUint32(&fd.haveBuf) != 0 {
+ var err error
+ fd.bufMu.Lock()
+ if len(fd.buf) != 0 {
+ var n int
+ n, err = dst.CopyOut(ctx, fd.buf)
+ dst = dst.DropFirst(n)
+ fd.buf = fd.buf[n:]
+ if len(fd.buf) == 0 {
+ atomic.StoreUint32(&fd.haveBuf, 0)
+ fd.buf = nil
+ }
+ bufN = int64(n)
+ if offset >= 0 {
+ offset += bufN
+ }
+ }
+ fd.bufMu.Unlock()
+ if err != nil {
+ return bufN, err
+ }
+ }
+
+ // Going through dst.CopyOutFrom() would hold MM locks around file
+ // operations of unknown duration. For regularFileFD, doing so is necessary
+ // to support mmap due to lock ordering; MM locks precede dentry.dataMu.
+ // That doesn't hold here since specialFileFD doesn't client-cache data.
+ // Just buffer the read instead.
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
if n == 0 {
- return 0, err
+ return bufN, err
}
if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
- return int64(cp), cperr
+ return bufN + int64(cp), cperr
}
- return int64(n), err
+ return bufN + int64(n), err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
d := fd.dentry()
- // If the regular file fd was opened with O_APPEND, make sure the file size
- // is updated. There is a possible race here if size is modified externally
- // after metadata cache is updated.
- if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
- if err := d.updateFromGetattr(ctx); err != nil {
- return 0, offset, err
+ if fd.isRegularFile {
+ // If the regular file fd was opened with O_APPEND, make sure the file
+ // size is updated. There is a possible race here if size is modified
+ // externally after metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
- }
- if fd.isRegularFile {
// We need to hold the metadataMu *while* writing to a regular file.
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
@@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- // If we have a host FD, fsyncing it is likely to be faster than an fsync
- // RPC.
- if fd.handle.fd >= 0 {
- ctx.UninterruptibleSleepStart(false)
- err := syscall.Fsync(int(fd.handle.fd))
- ctx.UninterruptibleSleepFinish(false)
- return err
+ return fd.sync(ctx, false /* forFilesystemSync */)
+}
+
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+ err := func() error {
+ // If we have a host FD, fsyncing it is likely to be faster than an fsync
+ // RPC.
+ if fd.handle.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(fd.handle.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ return fd.handle.file.fsync(ctx)
+ }()
+ if err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (fd represents a regular file that was opened for writing).
+ if fd.isRegularFile && fd.vfsfd.IsWritable() {
+ return err
+ }
+ ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err)
}
- return fd.handle.file.fsync(ctx)
+ return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 7e825caae..9cbe805b9 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,7 +17,6 @@ package gofer
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
-func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
- return ts.Sec*1e9 + int64(ts.Nsec)
-}
-
-func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
- return linux.StatxTimestamp{
- Sec: ns / 1e9,
- Nsec: uint32(ns % 1e9),
- }
-}
-
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 56bcf9bdb..4ae9d6d5e 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "inode_refs.go",
package = "host",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "connected_endpoint_refs.go",
package = "host",
prefix = "ConnectedEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ConnectedEndpoint",
},
@@ -33,7 +33,7 @@ go_library(
"host.go",
"inode_refs.go",
"ioctl_unsafe.go",
- "mmap.go",
+ "save_restore.go",
"socket.go",
"socket_iovec.go",
"socket_unsafe.go",
@@ -51,6 +51,7 @@ go_library(
"//pkg/log",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
index 0135e4428..13ef48cb5 100644
--- a/pkg/sentry/fsimpl/host/control.go
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription {
}
// Create the file backed by hostFD.
- file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */)
+ file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{})
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 698e913fe..39b902a3e 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -19,6 +19,7 @@ package host
import (
"fmt"
"math"
+ "sync/atomic"
"syscall"
"golang.org/x/sys/unix"
@@ -40,34 +41,97 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) {
- // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
- // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
- // devices.
+// inode implements kernfs.Inode.
+//
+// +stateify savable
+type inode struct {
+ kernfs.InodeNoStatFS
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.CachedMappable
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
+
+ locks vfs.FileLocks
+
+ // When the reference count reaches zero, the host fd is closed.
+ inodeRefs
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // ftype is the file's type (a linux.S_IFMT mask).
+ //
+ // This field is initialized at creation time and is immutable.
+ ftype uint16
+
+ // mayBlock is true if hostFD is non-blocking, and operations on it may
+ // return EAGAIN or EWOULDBLOCK instead of blocking.
+ //
+ // This field is initialized at creation time and is immutable.
+ mayBlock bool
+
+ // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file
+ // offsets are meaningful iff seekable is true.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // savable is true if hostFD may be saved/restored by its numeric value.
+ //
+ // This field is initialized at creation time and is immutable.
+ savable bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to inode.beforeSave(). haveBuf
+ // and buf are protected by bufMu. haveBuf is accessed using atomic memory
+ // operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
+}
+
+func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) {
+ // Determine if hostFD is seekable.
_, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
seekable := err != syserror.ESPIPE
+ // We expect regular files to be seekable, as this is required for them to
+ // be memory-mappable.
+ if !seekable && fileType == syscall.S_IFREG {
+ ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD)
+ return nil, syserror.ESPIPE
+ }
i := &inode{
- hostFD: hostFD,
- ino: fs.NextIno(),
- isTTY: isTTY,
- wouldBlock: wouldBlock(uint32(fileType)),
- seekable: seekable,
- // NOTE(b/38213152): Technically, some obscure char devices can be memory
- // mapped, but we only allow regular files.
- canMap: fileType == linux.S_IFREG,
- }
- i.pf.inode = i
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ ftype: uint16(fileType),
+ mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR,
+ seekable: seekable,
+ isTTY: isTTY,
+ savable: savable,
+ }
+ i.CachedMappable.Init(hostFD)
i.EnableLeakCheck()
- // Non-seekable files can't be memory mapped, assert this.
- if !i.seekable && i.canMap {
- panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
- }
-
- // If the hostFD would block, we must set it to non-blocking and handle
- // blocking behavior in the sentry.
- if i.wouldBlock {
+ // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and
+ // handle blocking behavior in the sentry.
+ if i.mayBlock {
if err := syscall.SetNonblock(i.hostFD, true); err != nil {
return nil, err
}
@@ -80,6 +144,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (
// NewFDOptions contains options to NewFD.
type NewFDOptions struct {
+ // If Savable is true, the host file descriptor may be saved/restored by
+ // numeric value; the sandbox API requires a corresponding host FD with the
+ // same numeric value to be provieded at time of restore.
+ Savable bool
+
// If IsTTY is true, the file descriptor is a TTY.
IsTTY bool
@@ -114,7 +183,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
}
d := &kernfs.Dentry{}
- i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
+ i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
if err != nil {
return nil, err
}
@@ -132,7 +201,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
return NewFD(ctx, mnt, hostFD, &NewFDOptions{
- IsTTY: isTTY,
+ Savable: true,
+ IsTTY: isTTY,
})
}
@@ -191,68 +261,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
-// inode implements kernfs.Inode.
-//
-// +stateify savable
-type inode struct {
- kernfs.InodeNoStatFS
- kernfs.InodeNotDirectory
- kernfs.InodeNotSymlink
- kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
-
- locks vfs.FileLocks
-
- // When the reference count reaches zero, the host fd is closed.
- inodeRefs
-
- // hostFD contains the host fd that this file was originally created from,
- // which must be available at time of restore.
- //
- // This field is initialized at creation time and is immutable.
- hostFD int
-
- // ino is an inode number unique within this filesystem.
- //
- // This field is initialized at creation time and is immutable.
- ino uint64
-
- // isTTY is true if this file represents a TTY.
- //
- // This field is initialized at creation time and is immutable.
- isTTY bool
-
- // seekable is false if the host fd points to a file representing a stream,
- // e.g. a socket or a pipe. Such files are not seekable and can return
- // EWOULDBLOCK for I/O operations.
- //
- // This field is initialized at creation time and is immutable.
- seekable bool
-
- // wouldBlock is true if the host FD would return EWOULDBLOCK for
- // operations that would block.
- //
- // This field is initialized at creation time and is immutable.
- wouldBlock bool
-
- // Event queue for blocking operations.
- queue waiter.Queue
-
- // canMap specifies whether we allow the file to be memory mapped.
- //
- // This field is initialized at creation time and is immutable.
- canMap bool
-
- // mapsMu protects mappings.
- mapsMu sync.Mutex `state:"nosave"`
-
- // If canMap is true, mappings tracks mappings of hostFD into
- // memmap.MappingSpaces.
- mappings memmap.MappingSet
-
- // pf implements platform.File for mappings of hostFD.
- pf inodePlatformFile
-}
-
// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s syscall.Stat_t
@@ -422,14 +430,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
oldpgend, _ := usermem.PageRoundUp(oldSize)
newpgend, _ := usermem.PageRoundUp(s.Size)
if oldpgend != newpgend {
- i.mapsMu.Lock()
- i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
- // Compare Linux's mm/truncate.c:truncate_setsize() =>
- // truncate_pagecache() =>
- // mm/memory.c:unmap_mapping_range(evencows=1).
- InvalidatePrivate: true,
- })
- i.mapsMu.Unlock()
+ i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend})
}
}
}
@@ -448,7 +449,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
// DecRef implements kernfs.Inode.DecRef.
func (i *inode) DecRef(ctx context.Context) {
i.inodeRefs.DecRef(func() {
- if i.wouldBlock {
+ if i.mayBlock {
fdnotifier.RemoveFD(int32(i.hostFD))
}
if err := unix.Close(i.hostFD); err != nil {
@@ -567,6 +568,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin
// PRead implements vfs.FileDescriptionImpl.PRead.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
return 0, syserror.ESPIPE
@@ -577,19 +585,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off
// Read implements vfs.FileDescriptionImpl.Read.
func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
+ bufN, err := i.readFromBuf(ctx, &dst)
+ if err != nil {
+ return bufN, err
+ }
n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ total := bufN + n
if isBlockError(err) {
// If we got any data at all, return it as a "completed" partial read
// rather than retrying until complete.
- if n != 0 {
+ if total != 0 {
err = nil
} else {
err = syserror.ErrWouldBlock
}
}
- return n, err
+ return total, err
}
f.offsetMu.Lock()
@@ -599,13 +619,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
return n, err
}
-func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
- // Check that flags are supported.
- //
- // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
- if flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) {
+ if atomic.LoadUint32(&i.haveBuf) == 0 {
+ return 0, nil
+ }
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ if len(i.buf) == 0 {
+ return 0, nil
}
+ n, err := dst.CopyOut(ctx, i.buf)
+ *dst = dst.DropFirst(n)
+ i.buf = i.buf[n:]
+ if len(i.buf) == 0 {
+ atomic.StoreUint32(&i.haveBuf, 0)
+ i.buf = nil
+ }
+ return int64(n), err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := dst.CopyOutFrom(ctx, reader)
hostfd.PutReadWriterAt(reader)
@@ -735,31 +768,37 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
}
// Sync implements vfs.FileDescriptionImpl.Sync.
-func (f *fileDescription) Sync(context.Context) error {
+func (f *fileDescription) Sync(ctx context.Context) error {
// TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
return unix.Fsync(f.inode.hostFD)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
- if !f.inode.canMap {
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ if f.inode.ftype != syscall.S_IFREG {
return syserror.ENODEV
}
i := f.inode
- i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+ i.CachedMappable.InitFileMapperOnce()
return vfs.GenericConfigureMMap(&f.vfsfd, i, opts)
}
// EventRegister implements waiter.Waitable.EventRegister.
func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
f.inode.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (f *fileDescription) EventUnregister(e *waiter.Entry) {
f.inode.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// Readiness uses the poll() syscall to check the status of the underlying FD.
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
new file mode 100644
index 000000000..8800652a9
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// beforeSave is invoked by stateify.
+func (i *inode) beforeSave() {
+ if !i.savable {
+ panic("host.inode is not savable")
+ }
+ if i.ftype == syscall.S_IFIFO {
+ // If this pipe FD is readable, drain it so that bytes in the pipe can
+ // be read after restore. (This is a legacy VFS1 feature.) We don't
+ // know if the pipe FD is readable, so just try reading and tolerate
+ // EBADF from the read.
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */)
+ if n != 0 {
+ i.buf = append(i.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF {
+ break
+ }
+ panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err))
+ }
+ }
+ if len(i.buf) != 0 {
+ atomic.StoreUint32(&i.haveBuf, 1)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inode) afterLoad() {
+ if i.mayBlock {
+ if err := syscall.SetNonblock(i.hostFD, true); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err))
+ }
+ if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 412bdb2eb..b2f43a119 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
}
-// wouldBlock returns true for file types that can return EWOULDBLOCK
-// for blocking operations, e.g. pipes, character devices, and sockets.
-func wouldBlock(fileType uint32) bool {
- return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
-}
-
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
// If so, they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 858cc24ce..6dbc7e34d 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "kernfs",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Dentry",
+ "Linker": "*Dentry",
+ },
+)
+
+go_template_instance(
name = "fstree",
out = "fstree.go",
package = "kernfs",
@@ -27,22 +39,11 @@ go_template_instance(
)
go_template_instance(
- name = "dentry_refs",
- out = "dentry_refs.go",
- package = "kernfs",
- prefix = "Dentry",
- template = "//pkg/refs_vfs2:refs_template",
- types = {
- "T": "Dentry",
- },
-)
-
-go_template_instance(
name = "static_directory_refs",
out = "static_directory_refs.go",
package = "kernfs",
prefix = "StaticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "StaticDirectory",
},
@@ -53,7 +54,7 @@ go_template_instance(
out = "dir_refs.go",
package = "kernfs_test",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -64,7 +65,7 @@ go_template_instance(
out = "readonly_dir_refs.go",
package = "kernfs_test",
prefix = "readonlyDir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "readonlyDir",
},
@@ -75,7 +76,7 @@ go_template_instance(
out = "synthetic_directory_refs.go",
package = "kernfs",
prefix = "syntheticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "syntheticDirectory",
},
@@ -84,13 +85,15 @@ go_template_instance(
go_library(
name = "kernfs",
srcs = [
- "dentry_refs.go",
+ "dentry_list.go",
"dynamic_bytes_file.go",
"fd_impl_util.go",
"filesystem.go",
"fstree.go",
"inode_impl_util.go",
"kernfs.go",
+ "mmap_util.go",
+ "save_restore.go",
"slot_list.go",
"static_directory_refs.go",
"symlink.go",
@@ -104,8 +107,12 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
@@ -129,6 +136,7 @@ go_test(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index b929118b1..485504995 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -47,11 +47,11 @@ type DynamicBytesFile struct {
var _ Inode = (*DynamicBytesFile)(nil)
// Init initializes a dynamic bytes file.
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
f.data = data
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index abf1905d6..f8dae22f8 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
return fd.vfsfd.VirtualDentry().Mount().Filesystem()
}
+func (fd *GenericDirectoryFD) dentry() *Dentry {
+ return fd.vfsfd.Dentry().Impl().(*Dentry)
+}
+
func (fd *GenericDirectoryFD) inode() Inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return fd.dentry().inode
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
@@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// Handle "..".
if fd.off == 1 {
- vfsd := fd.vfsfd.VirtualDentry().Dentry()
- parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
+ parentInode := genericParentOrSelf(fd.dentry()).inode
stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
@@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
var err error
relOffset := fd.off - int64(len(fd.children.set)) - 2
- fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset)
return err
}
@@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
- inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(ctx, fd.filesystem(), creds, opts)
+ return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts)
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 6426a55f6..e77523f22 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
// * isDir(parentInode) == true.
-func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) {
- if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
- return "", err
+func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error {
+ if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
}
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return "", syserror.EEXIST
+ if name == "." || name == ".." {
+ return syserror.EEXIST
}
- if len(pc) > linux.NAME_MAX {
- return "", syserror.ENAMETOOLONG
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
}
- if _, ok := parent.children[pc]; ok {
- return "", syserror.EEXIST
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
}
if parent.VFSDentry().IsDead() {
- return "", syserror.ENOENT
+ return syserror.ENOENT
}
- return pc, nil
+ return nil
}
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
@@ -245,7 +244,41 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *Filesystem) Release(context.Context) {
+func (fs *Filesystem) Release(ctx context.Context) {
+ root := fs.root
+ if root == nil {
+ return
+ }
+ fs.mu.Lock()
+ root.releaseKeptDentriesLocked(ctx)
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+ fs.mu.Unlock()
+ // Drop ref acquired in Dentry.InitRoot().
+ root.DecRef(ctx)
+}
+
+// releaseKeptDentriesLocked recursively drops all dentry references created by
+// Lookup when Dentry.inode.Keep() is true.
+//
+// Precondition: Filesystem.mu is held.
+func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) {
+ if d.inode.Keep() && d != d.fs.root {
+ d.decRefLocked(ctx)
+ }
+
+ if d.isDir() {
+ var children []*Dentry
+ d.dirMu.Lock()
+ for _, child := range d.children {
+ children = append(children, child)
+ }
+ d.dirMu.Unlock()
+ for _, child := range children {
+ child.releaseKeptDentriesLocked(ctx)
+ }
+ }
}
// Sync implements vfs.FilesystemImpl.Sync.
@@ -318,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if rp.Mount() != vd.Mount() {
return syserror.EXDEV
}
@@ -360,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -373,7 +409,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
- childI = newSyntheticDirectory(rp.Credentials(), opts.Mode)
+ childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode)
}
var child Dentry
child.Init(fs, childI)
@@ -396,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
@@ -517,9 +556,6 @@ afterTrailingSymlink:
}
var child Dentry
child.Init(fs, childI)
- // FIXME(gvisor.dev/issue/1193): Race between checking existence with
- // fs.stepExistingLocked and parent.insertChild. If possible, we should hold
- // dirMu from one to the other.
parent.insertChild(pc, &child)
// Open may block so we need to unlock fs.mu. IncRef child to prevent
// its destruction while fs.mu is unlocked.
@@ -626,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Can we create the dst dentry?
var dst *Dentry
- pc, err := checkCreateLocked(ctx, rp, dstDir)
- switch err {
+ pc := rp.Component()
+ switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err {
case nil:
// Ok, continue with rename as replacement.
case syserror.EEXIST:
@@ -791,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 122b10591..d83c17f83 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -21,9 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// InodeNoopRefCount partially implements the Inode interface, specifically the
@@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error)
}
// IterDirents implements Inode.IterDirents.
-func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
panic("IterDirents called on non-directory inode")
}
@@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry,
//
// +stateify savable
type InodeAttrs struct {
- devMajor uint32
- devMinor uint32
- ino uint64
- mode uint32
- uid uint32
- gid uint32
- nlink uint32
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+ blockSize uint32
+
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
}
// Init initializes this InodeAttrs.
-func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
+func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
if mode.FileType() == 0 {
panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
}
@@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in
atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
atomic.StoreUint32(&a.nlink, nlink)
+ atomic.StoreUint32(&a.blockSize, usermem.PageSize)
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.atime, now)
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
}
// DevMajor returns the device major number.
@@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode {
return linux.FileMode(atomic.LoadUint32(&a.mode))
}
+// TouchAtime updates a.atime to the current time.
+func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) {
+ if mnt.Flags.NoATime || mnt.ReadOnly() {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds())
+ mnt.EndWrite()
+}
+
+// TouchCMtime updates a.{c/m}time to the current time. The caller should
+// synchronize calls to this so that ctime and mtime are updated to the same
+// value.
+func (a *InodeAttrs) TouchCMtime(ctx context.Context) {
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
+}
+
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME
stat.DevMajor = a.devMajor
stat.DevMinor = a.devMinor
stat.Ino = atomic.LoadUint64(&a.ino)
@@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li
stat.UID = atomic.LoadUint32(&a.uid)
stat.GID = atomic.LoadUint32(&a.gid)
stat.Nlink = atomic.LoadUint32(&a.nlink)
-
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
-
+ stat.Blksize = atomic.LoadUint32(&a.blockSize)
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime))
return stat, nil
}
// SetStat implements Inode.SetStat.
func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- return a.SetInodeStat(ctx, fs, creds, opts)
-}
-
-// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
-// This function can be used by other kernfs-based filesystem implementation to
-// sets the unexported attributes into InodeAttrs.
-func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
}
@@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
// inode numbers are immutable after node creation. Setting the size is often
// allowed by kernfs files but does not do anything. If some other behavior is
// needed, the embedder should consider extending SetStat.
- //
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
@@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
atomic.StoreUint32(&a.gid, stat.GID)
}
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ stat.Atime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.atime, stat.Atime.ToNsec())
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec())
+ }
+
return nil
}
@@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error
}
// IterDirents implements Inode.IterDirents.
-func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
// All entries from OrderedChildren have already been handled in
// GenericDirectoryFD.IterDirents.
return offset, nil
@@ -528,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e
return o.Unlink(ctx, name, child)
}
-// +stateify savable
-type renameAcrossDifferentImplementationsError struct{}
-
-func (renameAcrossDifferentImplementationsError) Error() string {
- return "rename across inodes with different implementations"
-}
-
// Rename implements Inode.Rename.
//
// Precondition: Rename may only be called across two directory inodes with
@@ -545,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string {
//
// Postcondition: reference on any replaced dentry transferred to caller.
func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+
dst, ok := dstDir.(interface{}).(*OrderedChildren)
if !ok {
- return renameAcrossDifferentImplementationsError{}
+ return syserror.EXDEV
}
- if !o.writable || !dst.writable {
+ if !dst.writable {
return syserror.EPERM
}
+
// Note: There's a potential deadlock below if concurrent calls to Rename
// refer to the same src and dst directories in reverse. We avoid any
// ordering issues because the caller is required to serialize concurrent
@@ -619,9 +657,9 @@ type StaticDirectory struct {
var _ Inode = (*StaticDirectory)(nil)
// NewStaticDir creates a new static directory and returns its dentry.
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
+func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
- inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(OrderedChildrenOptions{})
@@ -632,12 +670,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64
}
// Init initializes StaticDirectory.
-func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
+func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
s.fdOpts = fdOpts
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
// Open implements Inode.Open.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 606081e68..abb477c7d 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -107,6 +108,23 @@ type Filesystem struct {
// nextInoMinusOne is used to to allocate inode numbers on this
// filesystem. Must be accessed by atomic operations.
nextInoMinusOne uint64
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by mu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // MaxCachedDentries is the maximum size of cachedDentries. If not set,
+ // defaults to 0 and kernfs does not cache any dentries. This is immutable.
+ MaxCachedDentries uint64
+
+ // root is the root dentry of this filesystem. Note that root may be nil for
+ // filesystems on a disconnected mount without a root (e.g. pipefs, sockfs,
+ // hostfs). Filesystem holds an extra reference on root to prevent it from
+ // being destroyed prematurely. This is immutable.
+ root *Dentry
}
// deferDecRef defers dropping a dentry ref until the next call to
@@ -165,7 +183,12 @@ const (
// +stateify savable
type Dentry struct {
vfsd vfs.Dentry
- DentryRefs
+
+ // refs is the reference count. When refs reaches 0, the dentry may be
+ // added to the cache or destroyed. If refs == -1, the dentry has already
+ // been destroyed. refs are allowed to go to 0 and increase again. refs is
+ // accessed using atomic memory operations.
+ refs int64
// fs is the owning filesystem. fs is immutable.
fs *Filesystem
@@ -177,6 +200,12 @@ type Dentry struct {
parent *Dentry
name string
+ // If cached is true, dentryEntry links dentry into
+ // Filesystem.cachedDentries. cached and dentryEntry are protected by
+ // Filesystem.mu.
+ cached bool
+ dentryEntry
+
// dirMu protects children and the names of child Dentries.
//
// Note that holding fs.mu for writing is not sufficient;
@@ -188,6 +217,201 @@ type Dentry struct {
inode Inode
}
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *Dentry) IncRef() {
+ // d.refs may be 0 if d.fs.mu is locked, which serializes against
+ // d.cacheLocked().
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *Dentry) TryIncRef() bool {
+ for {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.fs.mu.Lock()
+ d.cacheLocked(ctx)
+ d.fs.mu.Unlock()
+ } else if r < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+}
+
+func (d *Dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.cacheLocked(ctx)
+ } else if r < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+}
+
+// cacheLocked should be called after d's reference count becomes 0. The ref
+// count check may happen before acquiring d.fs.mu so there might be a race
+// condition where the ref count is increased again by the time the caller
+// acquires d.fs.mu. This race is handled.
+// Only reachable dentries are added to the cache. However, a dentry might
+// become unreachable *while* it is in the cache due to invalidation.
+//
+// Preconditions: d.fs.mu must be locked for writing.
+func (d *Dentry) cacheLocked(ctx context.Context) {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires d.fs.mu, so if d.refs is zero then it will
+ // remain zero while we hold d.fs.mu for writing.)
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d))
+ }
+ if refs > 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ // If the dentry is deleted and invalidated or has no parent, then it is no
+ // longer reachable by path resolution and should be dropped immediately
+ // because it has zero references.
+ // Note that a dentry may not always have a parent; for example magic links
+ // as described in Inode.Getlink.
+ if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil {
+ if !isDead {
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
+ }
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked(ctx)
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries {
+ return
+ }
+ d.fs.evictCachedDentryLocked(ctx)
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.opts.maxCachedDentries, so we don't loop.
+}
+
+// Preconditions:
+// * fs.mu must be locked for writing.
+// * fs.cachedDentriesLen != 0.
+func (fs *Filesystem) evictCachedDentryLocked(ctx context.Context) {
+ // Evict the least recently used dentry because cache size is greater than
+ // max cache size (configured on mount).
+ victim := fs.cachedDentries.Back()
+ fs.cachedDentries.Remove(victim)
+ fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // after it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if !victim.vfsd.IsDead() {
+ victim.parent.dirMu.Lock()
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry())
+ delete(victim.parent.children, victim.name)
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.MaxCachedDentries, so we don't loop.
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions:
+// * d.fs.mu must be locked for writing.
+// * d.refs == 0.
+// * d should have been removed from d.parent.children, i.e. d is not reachable
+// by path traversal.
+// * d.vfsd.IsDead() is true.
+func (d *Dentry) destroyLocked(ctx context.Context) {
+ refs := atomic.LoadInt64(&d.refs)
+ switch refs {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
+ d.inode.DecRef(ctx) // IncRef from Init.
+ d.inode = nil
+
+ if d.parent != nil {
+ d.parent.decRefLocked(ctx)
+ }
+
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *Dentry) RefType() string {
+ return "kernfs.Dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *Dentry) LeakMessage() string {
+ return fmt.Sprintf("[kernfs.Dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *Dentry) LogRefs() bool {
+ return false
+}
+
+// InitRoot initializes this dentry as the root of the filesystem.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) InitRoot(fs *Filesystem, inode Inode) {
+ d.Init(fs, inode)
+ fs.root = d
+ // Hold an extra reference on the root dentry. It is held by fs to prevent the
+ // root from being "cached" and subsequently evicted.
+ d.IncRef()
+}
+
// Init initializes this dentry.
//
// Precondition: Caller must hold a reference on inode.
@@ -197,6 +421,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
d.vfsd.Init(d)
d.fs = fs
d.inode = inode
+ atomic.StoreInt64(&d.refs, 1)
ftype := inode.Mode().FileType()
if ftype == linux.ModeDirectory {
d.flags |= dflagsIsDir
@@ -204,7 +429,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
if ftype == linux.ModeSymlink {
d.flags |= dflagsIsSymlink
}
- d.EnableLeakCheck()
+ refsvfs2.Register(d)
}
// VFSDentry returns the generic vfs dentry for this kernfs dentry.
@@ -222,32 +447,6 @@ func (d *Dentry) isSymlink() bool {
return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
}
-// DecRef implements vfs.DentryImpl.DecRef.
-func (d *Dentry) DecRef(ctx context.Context) {
- decRefParent := false
- d.fs.mu.Lock()
- d.DentryRefs.DecRef(func() {
- d.inode.DecRef(ctx) // IncRef from Init.
- d.inode = nil
- if d.parent != nil {
- // We will DecRef d.parent once all locks are dropped.
- decRefParent = true
- d.parent.dirMu.Lock()
- // Remove d from parent.children. It might already have been
- // removed due to invalidation.
- if _, ok := d.parent.children[d.name]; ok {
- delete(d.parent.children, d.name)
- d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
- }
- d.parent.dirMu.Unlock()
- }
- })
- d.fs.mu.Unlock()
- if decRefParent {
- d.parent.DecRef(ctx) // IncRef from Dentry.insertChild.
- }
-}
-
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
//
// Although Linux technically supports inotify on pseudo filesystems (inotify
@@ -267,7 +466,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {}
// this dentry. This does not update the directory inode, so calling this on its
// own isn't sufficient to insert a child into a directory.
//
-// Precondition: d must represent a directory inode.
+// Preconditions:
+// * d must represent a directory inode.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChild(name string, child *Dentry) {
d.dirMu.Lock()
d.insertChildLocked(name, child)
@@ -280,6 +481,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) {
// Preconditions:
// * d must represent a directory inode.
// * d.dirMu must be locked.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChildLocked(name string, child *Dentry) {
if !d.isDir() {
panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d))
@@ -436,7 +638,7 @@ type inodeDirectory interface {
// the inode is a directory.
//
// The child returned by Lookup will be hashed into the VFS dentry tree,
- // atleast for the duration of the current FS operation.
+ // at least for the duration of the current FS operation.
//
// Lookup must return the child with an extra reference whose ownership is
// transferred to the dentry that is created to point to that inode. If
@@ -454,7 +656,7 @@ type inodeDirectory interface {
// inside the entries returned by this IterDirents invocation. In other words,
// 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
// the return value, while 'relOffset' is the place to start iteration.
- IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
+ IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
type inodeSymlink interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 82fa19c03..2418eec44 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file."
// RootDentryFn is a generator function for creating the root dentry of a test
// filesystem. See newTestSystem.
-type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode
+type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode
// newTestSystem sets up a minimal environment for running a test, including an
// instance of a test filesystem. Tests can control the contents of the
@@ -72,10 +72,10 @@ type file struct {
content string
}
-func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode {
+func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode {
f := &file{}
f.content = content
- f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
+ f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
return f
}
@@ -105,9 +105,9 @@ type readonlyDir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &readonlyDir{}
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
dir.EnableLeakCheck()
dir.IncLinks(dir.OrderedChildren.Populate(contents))
@@ -142,10 +142,10 @@ type dir struct {
fs *filesystem
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &dir{}
dir.fs = fs
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
dir.EnableLeakCheck()
@@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) {
func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- dir := d.fs.newDir(creds, opts.Mode, nil)
+ dir := d.fs.newDir(ctx, creds, opts.Mode, nil)
if err := d.OrderedChildren.Insert(name, dir); err != nil {
dir.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
d.IncLinks(1)
return dir, nil
}
func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- f := d.fs.newFile(creds, "")
+ f := d.fs.newFile(ctx, creds, "")
if err := d.OrderedChildren.Insert(name, f); err != nil {
f.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
return f, nil
}
@@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {}
func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
fs.VFSFilesystem().Init(vfsObj, &fst, fs)
- root := fst.rootFn(creds, fs)
+ root := fst.rootFn(ctx, creds, fs)
var d kernfs.Dentry
d.Init(&fs.Filesystem, root)
return fs.VFSFilesystem(), d.VFSDentry(), nil
@@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst
// -------------------- Remainder of the file are test cases --------------------
func TestBasic(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -228,9 +230,9 @@ func TestBasic(t *testing.T) {
}
func TestMkdirGetDentry(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) {
}
func TestReadStaticFile(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) {
}
func TestCreateNewFileInStaticDir(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
func TestDirFDReadWrite(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, nil)
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, nil)
})
defer sys.Destroy()
@@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) {
}
func TestDirFDIterDirents(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
// Fill root with nodes backed by various inode implementations.
- "dir1": fs.newReadonlyDir(creds, 0755, nil),
- "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{
- "dir3": fs.newDir(creds, 0755, nil),
+ "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil),
+ "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir3": fs.newDir(ctx, creds, 0755, nil),
}),
- "file1": fs.newFile(creds, staticFileContent),
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go
index b51a17bed..bd6a134b4 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package host
+package kernfs
import (
"gvisor.dev/gvisor/pkg/context"
@@ -26,11 +26,14 @@ import (
// inodePlatformFile implements memmap.File. It exists solely because inode
// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
-// inodePlatformFile should only be used if inode.canMap is true.
-//
// +stateify savable
type inodePlatformFile struct {
- *inode
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ // inodePlatformFile does not own hostFD and hence should not close it.
+ hostFD int
// fdRefsMu protects fdRefs.
fdRefsMu sync.Mutex `state:"nosave"`
@@ -43,12 +46,12 @@ type inodePlatformFile struct {
fileMapper fsutil.HostFileMapper
// fileMapperInitOnce is used to lazily initialize fileMapper.
- fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ fileMapperInitOnce sync.Once `state:"nosave"`
}
+var _ memmap.File = (*inodePlatformFile)(nil)
+
// IncRef implements memmap.File.IncRef.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.IncRefAndAccount(fr)
@@ -56,8 +59,6 @@ func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
}
// DecRef implements memmap.File.DecRef.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.DecRefAndAccount(fr)
@@ -65,8 +66,6 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
}
// MapInternal implements memmap.File.MapInternal.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
@@ -76,10 +75,32 @@ func (i *inodePlatformFile) FD() int {
return i.hostFD
}
-// AddMapping implements memmap.Mappable.AddMapping.
+// CachedMappable implements memmap.Mappable. This utility can be embedded in a
+// kernfs.Inode that represents a host file to make the inode mappable.
+// CachedMappable caches the mappings of the host file. CachedMappable must be
+// initialized (via Init) with a hostFD before use.
//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+// +stateify savable
+type CachedMappable struct {
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of hostFD into memmap.MappingSpaces.
+ mappings memmap.MappingSet
+
+ // pf implements memmap.File for mappings backed by a host fd.
+ pf inodePlatformFile
+}
+
+var _ memmap.Mappable = (*CachedMappable)(nil)
+
+// Init initializes i.pf. This must be called before using CachedMappable.
+func (i *CachedMappable) Init(hostFD int) {
+ i.pf.hostFD = hostFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
i.mapsMu.Lock()
mapped := i.mappings.AddMapping(ms, ar, offset, writable)
for _, r := range mapped {
@@ -90,9 +111,7 @@ func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar userm
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
i.mapsMu.Lock()
unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
for _, r := range unmapped {
@@ -102,16 +121,12 @@ func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar us
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
return i.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
mr := optional
return []memmap.Translation{
{
@@ -124,10 +139,26 @@ func (i *inode) Translate(ctx context.Context, required, optional memmap.Mappabl
}
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) InvalidateUnsavable(ctx context.Context) error {
+func (i *CachedMappable) InvalidateUnsavable(ctx context.Context) error {
// We expect the same host fd across save/restore, so all translations
// should be valid.
return nil
}
+
+// InvalidateRange invalidates the passed range on i.mappings.
+func (i *CachedMappable) InvalidateRange(r memmap.MappableRange) {
+ i.mapsMu.Lock()
+ i.mappings.Invalidate(r, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ i.mapsMu.Unlock()
+}
+
+// InitFileMapperOnce initializes the host file mapper. It ensures that the
+// file mapper is initialized just once.
+func (i *CachedMappable) InitFileMapperOnce() {
+ i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/save_restore.go b/pkg/sentry/fsimpl/kernfs/save_restore.go
new file mode 100644
index 000000000..f78509eb7
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/save_restore.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// afterLoad is invoked by stateify.
+func (d *Dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) >= 0 {
+ refsvfs2.Register(d)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodePlatformFile) afterLoad() {
+ if i.fileMapper.IsInited() {
+ // Ensure that we don't call i.fileMapper.Init() again.
+ i.fileMapperInitOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 934cc6c9e..a0736c0d6 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -38,16 +38,16 @@ type StaticSymlink struct {
var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
+func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
inode := &StaticSymlink{}
- inode.Init(creds, devMajor, devMinor, ino, target)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, target)
return inode
}
// Init initializes the instance.
-func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
+func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
s.target = target
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
}
// Readlink implements Inode.Readlink.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index d0ed17b18..463d77d79 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -41,17 +41,17 @@ type syntheticDirectory struct {
var _ Inode = (*syntheticDirectory)(nil)
-func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode {
+func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode {
inode := &syntheticDirectory{}
- inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
+ inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
return inode
}
-func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm))
}
- dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
+ dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
dir.OrderedChildren.Init(OrderedChildrenOptions{
Writable: true,
})
@@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs
if !opts.ForSyntheticMountpoint {
return nil, syserror.EPERM
}
- subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
if err := dir.OrderedChildren.Insert(name, subdirI); err != nil {
subdirI.DecRef(ctx)
return nil, err
}
+ dir.TouchCMtime(ctx)
return subdirI, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index 1e11b0428..bf13bbbf4 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -23,6 +23,7 @@ go_library(
"fstree.go",
"overlay.go",
"regular_file.go",
+ "save_restore.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
@@ -30,6 +31,8 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 4506642ca..469f3a33d 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er
if dirent.Name == "." || dirent.Name == ".." {
continue
}
- child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
+ child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 78a01bbb7..bc07d72c0 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
// * !rp.Done().
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) {
if !d.isDir() {
- return nil, syserror.ENOTDIR
+ return nil, lookupLayerNone, syserror.ENOTDIR
}
if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
afterSymlink:
name := rp.Component()
if name == "." {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if name == ".." {
if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
rp.Advance()
- return d.parent, nil
+ return d.parent, d.parent.topLookupLayer(), nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx)
if err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
goto afterSymlink // don't check the current directory again
}
rp.Advance()
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
-func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) {
if child, ok := parent.children[name]; ok {
- return child, nil
+ return child, child.topLookupLayer(), nil
}
- child, err := fs.lookupLocked(ctx, parent, name)
+ child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if parent.children == nil {
parent.children = make(map[string]*dentry)
@@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
parent.children[name] = child
// child's refcount is initially 0, so it may be dropped after traversal.
*ds = appendDentry(*ds, child)
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * parent.dirMu must be locked.
-func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) {
childPath := fspath.Parse(name)
child := fs.newDentry()
- existsOnAnyLayer := false
+ topLookupLayer := lookupLayerNone
var lookupErr error
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
defer childVD.DecRef(ctx)
mask := uint32(linux.STATX_TYPE)
- if !existsOnAnyLayer {
+ if topLookupLayer == lookupLayerNone {
// Mode, UID, GID, and (for non-directories) inode number come from
// the topmost layer on which the file exists.
mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
@@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if isWhiteout(&stat) {
// This is a whiteout, so it "doesn't exist" on this layer, and
// layers below this one are ignored.
+ if isUpper {
+ topLookupLayer = lookupLayerUpperWhiteout
+ }
return false
}
isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
- if existsOnAnyLayer && !isDir {
+ if topLookupLayer != lookupLayerNone && !isDir {
// Directories are not merged with non-directory files from lower
// layers; instead, layers including and below the first
// non-directory file are ignored. (This file must be a directory
@@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
} else {
child.lowerVDs = append(child.lowerVDs, childVD)
}
- if !existsOnAnyLayer {
- existsOnAnyLayer = true
+ if topLookupLayer == lookupLayerNone {
+ if isUpper {
+ topLookupLayer = lookupLayerUpper
+ } else {
+ topLookupLayer = lookupLayerLower
+ }
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
@@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if lookupErr != nil {
child.destroyLocked(ctx)
- return nil, lookupErr
+ return nil, topLookupLayer, lookupErr
}
- if !existsOnAnyLayer {
+ if !topLookupLayer.existsInOverlay() {
child.destroyLocked(ctx)
- return nil, syserror.ENOENT
+ return nil, topLookupLayer, syserror.ENOENT
}
// Device and inode numbers were copied from the topmost layer above;
@@ -302,14 +310,20 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
child.devMinor = fs.dirDevMinor
child.ino = fs.newDirIno()
} else if !child.upperVD.Ok() {
+ childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor)
+ if err != nil {
+ ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err)
+ child.destroyLocked(ctx)
+ return nil, topLookupLayer, err
+ }
child.devMajor = linux.UNNAMED_MAJOR
- child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ child.devMinor = childDevMinor
}
parent.IncRef()
child.parent = parent
child.name = name
- return child, nil
+ return child, topLookupLayer, nil
}
// lookupLayerLocked is similar to lookupLocked, but only returns information
@@ -408,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool {
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -428,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
d := rp.Start().Impl().(*dentry)
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -463,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if name == "." || name == ".." {
return syserror.EEXIST
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.vfsd.IsDead() {
return syserror.ENOENT
}
@@ -489,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
if err := parent.copyUpLocked(ctx); err != nil {
@@ -791,9 +806,9 @@ afterTrailingSymlink:
}
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
- fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout)
parent.dirMu.Unlock()
return fd, err
}
@@ -893,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
// Preconditions:
// * parent.dirMu must be locked.
// * parent does not already contain a child named rp.Component().
-func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) {
creds := rp.Credentials()
if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
return nil, err
@@ -918,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
Start: parent.upperVD,
Path: fspath.Parse(childName),
}
- // We don't know if a whiteout exists on the upper layer; speculatively
- // unlink it.
- //
- // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
- // know whether a whiteout exists.
- var haveUpperWhiteout bool
- switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
- case nil:
- haveUpperWhiteout = true
- case syserror.ENOENT:
- haveUpperWhiteout = false
- default:
- return nil, err
+ // Unlink the whiteout if it exists.
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err)
+ return nil, err
+ }
}
// Create the file on the upper layer, and get an FD representing it.
upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
@@ -961,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
// Re-lookup to get a dentry representing the new file, which is needed for
// the returned FD.
- child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ child, _, err := fs.getChildLocked(ctx, parent, childName, ds)
if err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr))
@@ -970,7 +978,10 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
- // Finally construct the overlay FD.
+ // Finally construct the overlay FD. Below this point, we don't perform
+ // cleanup (the file was created successfully even if we can no longer open
+ // it for some reason).
+ parent.dirents = nil
upperFlags := upperFD.StatusFlags()
fd := &regularFileFD{
copiedUp: true,
@@ -981,8 +992,6 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
upperFDOpts := upperFD.Options()
if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
upperFD.DecRef(ctx)
- // Don't bother with cleanup; the file was created successfully, we
- // just can't open it anymore for some reason.
return nil, err
}
parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */)
@@ -1040,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// directory, we need to check for write permission on it.
oldParent.dirMu.Lock()
defer oldParent.dirMu.Unlock()
- renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
+ renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
if err != nil {
return err
}
@@ -1072,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.vfsd.IsDead() {
return syserror.ENOENT
}
- replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName)
- if err != nil {
- return err
- }
var (
- replaced *dentry
- replacedVFSD *vfs.Dentry
- whiteouts map[string]bool
+ replaced *dentry
+ replacedVFSD *vfs.Dentry
+ replacedLayer lookupLayer
+ whiteouts map[string]bool
)
- if replacedLayer.existsInOverlay() {
- replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds)
- if err != nil {
- return err
- }
+ replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil && err != syserror.ENOENT {
+ return err
+ }
+ if replaced != nil {
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
@@ -1289,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// Unlike UnlinkAt, we need a dentry representing the child directory being
// removed in order to verify that it's empty.
- child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err := fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
@@ -1541,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if parentMode&linux.S_ISVTX != 0 {
// If the parent's sticky bit is set, we need a child dentry to get
// its owner.
- child, err = fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err = fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 4c5de8d32..73130bc8d 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -22,6 +22,7 @@
// filesystem.renameMu
// dentry.dirMu
// dentry.copyMu
+// filesystem.devMu
// *** "memmap.Mappable locks" below this point
// dentry.mapsMu
// *** "memmap.Mappable locks taken by Translate" below this point
@@ -33,12 +34,14 @@
package overlay
import (
+ "fmt"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -99,10 +102,15 @@ type filesystem struct {
// is immutable.
dirDevMinor uint32
- // lowerDevMinors maps lower layer filesystems to device minor numbers
- // assigned to non-directory files originating from that filesystem.
- // lowerDevMinors is immutable.
- lowerDevMinors map[*vfs.Filesystem]uint32
+ // lowerDevMinors maps device numbers from lower layer filesystems to
+ // device minor numbers assigned to non-directory files originating from
+ // that filesystem. (This remapping is necessary for lower layers because a
+ // file on a lower layer, and that same file on an overlay, are
+ // distinguishable because they will diverge after copy-up; this isn't true
+ // for non-directory files already on the upper layer.) lowerDevMinors is
+ // protected by devMu.
+ devMu sync.Mutex `state:"nosave"`
+ lowerDevMinors map[layerDevNumber]uint32
// renameMu synchronizes renaming with non-renaming operations in order to
// ensure consistent lock ordering between dentry.dirMu in different
@@ -114,78 +122,69 @@ type filesystem struct {
lastDirIno uint64
}
+// +stateify savable
+type layerDevNumber struct {
+ major uint32
+ minor uint32
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mopts := vfs.GenericParseMountOptions(opts.Data)
fsoptsRaw := opts.InternalData
- fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
- if fsoptsRaw != nil && !haveFSOpts {
+ fsopts, ok := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !ok {
ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
return nil, nil, syserror.EINVAL
}
- if haveFSOpts {
- if len(fsopts.LowerRoots) == 0 {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+
+ if upperPathname, ok := mopts["upperdir"]; ok {
+ if fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified")
return nil, nil, syserror.EINVAL
}
- if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
return nil, nil, syserror.EINVAL
}
- // We don't enforce a maximum number of lower layers when not
- // configured by applications; the sandbox owner can have an overlay
- // filesystem with any number of lower layers.
- } else {
- vfsroot := vfs.RootFromContext(ctx)
- defer vfsroot.DecRef(ctx)
- upperPathname, ok := mopts["upperdir"]
- if ok {
- delete(mopts, "upperdir")
- // Linux overlayfs also requires a workdir when upperdir is
- // specified; we don't, so silently ignore this option.
- delete(mopts, "workdir")
- upperPath := fspath.Parse(upperPathname)
- if !upperPath.Absolute {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
- return nil, nil, syserror.EINVAL
- }
- upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
- Root: vfsroot,
- Start: vfsroot,
- Path: upperPath,
- FollowFinalSymlink: true,
- }, &vfs.GetDentryOptions{
- CheckSearchable: true,
- })
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer upperRoot.DecRef(ctx)
- privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer privateUpperRoot.DecRef(ctx)
- fsopts.UpperRoot = privateUpperRoot
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ upperRoot.DecRef(ctx)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
}
- lowerPathnamesStr, ok := mopts["lowerdir"]
- if !ok {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ defer privateUpperRoot.DecRef(ctx)
+ fsopts.UpperRoot = privateUpperRoot
+ }
+
+ if lowerPathnamesStr, ok := mopts["lowerdir"]; ok {
+ if len(fsopts.LowerRoots) != 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified")
return nil, nil, syserror.EINVAL
}
delete(mopts, "lowerdir")
lowerPathnames := strings.Split(lowerPathnamesStr, ":")
- const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
- if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
- return nil, nil, syserror.EINVAL
- }
- if len(lowerPathnames) > maxLowerLayers {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
- return nil, nil, syserror.EINVAL
- }
for _, lowerPathname := range lowerPathnames {
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
@@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
- defer lowerRoot.DecRef(ctx)
privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ lowerRoot.DecRef(ctx)
if err != nil {
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
@@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
}
}
+
if len(mopts) != 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
return nil, nil, syserror.EINVAL
}
- // Allocate device numbers.
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present")
+ return nil, nil, syserror.EINVAL
+ }
+ const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
+ if len(fsopts.LowerRoots) > maxLowerLayers {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate dirDevMinor. lowerDevMinors are allocated dynamically.
dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
- lowerDevMinors := make(map[*vfs.Filesystem]uint32)
- for _, lowerRoot := range fsopts.LowerRoots {
- lowerFS := lowerRoot.Mount().Filesystem()
- if _, ok := lowerDevMinors[lowerFS]; !ok {
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- vfsObj.PutAnonBlockDevMinor(dirDevMinor)
- for _, lowerDevMinor := range lowerDevMinors {
- vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
- }
- return nil, nil, err
- }
- lowerDevMinors[lowerFS] = devMinor
- }
- }
// Take extra references held by the filesystem.
if fsopts.UpperRoot.Ok() {
@@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
opts: fsopts,
creds: creds.Fork(),
dirDevMinor: dirDevMinor,
- lowerDevMinors: lowerDevMinors,
+ lowerDevMinors: make(map[layerDevNumber]uint32),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
root.ino = fs.newDirIno()
} else if !root.upperVD.Ok() {
root.devMajor = linux.UNNAMED_MAJOR
- root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err)
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+ root.devMinor = rootDevMinor
root.ino = rootStat.Ino
} else {
root.devMajor = rootStat.DevMajor
@@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 {
return atomic.AddUint64(&fs.lastDirIno, 1)
}
+func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) {
+ fs.devMu.Lock()
+ defer fs.devMu.Unlock()
+ orig := layerDevNumber{layerMajor, layerMinor}
+ if minor, ok := fs.lowerDevMinors[orig]; ok {
+ return minor, nil
+ }
+ minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor()
+ if err != nil {
+ return 0, err
+ }
+ fs.lowerDevMinors[orig] = minor
+ return minor, nil
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -458,9 +479,9 @@ type dentry struct {
//
// - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is
// accessed using atomic memory operations.
- mapsMu sync.Mutex
+ mapsMu sync.Mutex `state:"nosave"`
lowerMappings memmap.MappingSet
- dataMu sync.RWMutex
+ dataMu sync.RWMutex `state:"nosave"`
wrappedMappable memmap.Mappable
isMappable uint32
@@ -484,6 +505,7 @@ func (fs *filesystem) newDentry() *dentry {
}
d.lowerVDs = d.inlineLowerVDs[:0]
d.vfsd.Init(d)
+ refsvfs2.Register(d)
return d
}
@@ -491,17 +513,19 @@ func (fs *filesystem) newDentry() *dentry {
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkDropLocked().
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -509,15 +533,27 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
+ } else if r < 0 {
panic("overlay.dentry.DecRef() called without holding a reference")
}
}
+func (d *dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.checkDropLocked(ctx)
+ } else if r < 0 {
+ panic("overlay.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
// checkDropLocked should be called after d's reference count becomes 0 or it
// becomes deleted.
//
@@ -577,12 +613,27 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.parent.dirMu.Unlock()
// Drop the reference held by d on its parent without recursively
// locking d.fs.renameMu.
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkDropLocked(ctx)
- } else if refs < 0 {
- panic("overlay.dentry.DecRef() called without holding a reference")
- }
+ d.parent.decRefLocked(ctx)
}
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "overlay.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -645,6 +696,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry {
return vd
}
+func (d *dentry) topLookupLayer() lookupLayer {
+ if d.upperVD.Ok() {
+ return lookupLayerUpper
+ }
+ return lookupLayerLower
+}
+
func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go
new file mode 100644
index 000000000..54809f16c
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index e44b79b68..0ecb592cf 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -101,7 +101,7 @@ type inode struct {
func newInode(ctx context.Context, fs *filesystem) *inode {
creds := auth.CredentialsFromContext(ctx)
return &inode{
- pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize),
ino: fs.Filesystem.NextIno(),
uid: creds.EffectiveKUID,
gid: creds.EffectiveKGID,
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 2e086e34c..5196a2a80 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "fd_dir_inode_refs.go",
package = "proc",
prefix = "fdDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdDirInode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "fd_info_dir_inode_refs.go",
package = "proc",
prefix = "fdInfoDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdInfoDirInode",
},
@@ -30,7 +30,7 @@ go_template_instance(
out = "subtasks_inode_refs.go",
package = "proc",
prefix = "subtasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "subtasksInode",
},
@@ -41,7 +41,7 @@ go_template_instance(
out = "task_inode_refs.go",
package = "proc",
prefix = "taskInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "taskInode",
},
@@ -52,7 +52,7 @@ go_template_instance(
out = "tasks_inode_refs.go",
package = "proc",
prefix = "tasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tasksInode",
},
@@ -82,6 +82,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index fd70a07de..8716d0a3c 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -17,6 +17,7 @@ package proc
import (
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -24,10 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "proc"
+const (
+ // Name is the default filesystem name.
+ Name = "proc"
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType is the factory class for procfs.
//
@@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
if err != nil {
return nil, nil, err
}
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
procfs := &filesystem{
devMinor: devMinor,
}
+ procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
var cgroups map[string]string
@@ -74,9 +92,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
cgroups = data.Cgroups
}
- inode := procfs.newTasksInode(k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
var dentry kernfs.Dentry
- dentry.Init(&procfs.Filesystem, inode)
+ dentry.InitRoot(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
@@ -94,11 +112,11 @@ type dynamicInode interface {
kernfs.Inode
vfs.DynamicBytesSource
- Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+ Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
}
-func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
+func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
return inode
}
@@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile {
return &staticFile{StaticData: vfs.StaticData{Data: data}}
}
-func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
- return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
+func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
+ return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index bad2fab4f..cb3c5e0fd 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
cgroupControllers: cgroupControllers,
}
// Note: credentials are overridden by taskOwnedInode.
- subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
subInode.EnableLeakCheck()
@@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
return offset, syserror.ENOENT
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index b63a4eca0..19011b010 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
"gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
"io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
"maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mem": fs.newMemInode(task, fs.NextIno(), 0400),
"mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
"mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}),
"net": fs.newTaskNetDir(task),
@@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode := &taskInode{task: task}
// Note: credentials are overridden by taskOwnedInode.
- taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
taskInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
@@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil)
func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
return &taskOwnedInode{Inode: inode, owner: task}
}
@@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu
func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero}
- dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
+ dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
return &taskOwnedInode{Inode: dir, owner: task}
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 2c80ac5c2..d268b44be 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -64,7 +64,7 @@ type fdDir struct {
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
@@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
produceSymlink: true,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Lookup implements kernfs.inodeDirectory.Lookup.
@@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern
task: task,
fd: fd,
}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
task: task,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
@@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements Inode.IterDirents.
-func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 79f8b7e9f..ba71d0fde 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -249,7 +250,7 @@ type commInode struct {
func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
inode := &commInode{task: task}
- inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
+ inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
return inode
}
@@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
return int64(srclen), nil
}
+var _ kernfs.Inode = (*memInode)(nil)
+
+// memInode implements kernfs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoStatFS
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ task *kernel.Task
+ locks vfs.FileLocks
+}
+
+func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode := &memInode{task: task}
+ inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm)
+ return &taskOwnedInode{Inode: inode, owner: task}
+}
+
+func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+}
+
+// Open implements kernfs.Inode.Open.
+func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, f.task, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(f.task); err != nil {
+ return nil, err
+ }
+ fd := &memFD{}
+ if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+var _ vfs.FileDescriptionImpl = (*memFD)(nil)
+
+// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem.
+//
+// +stateify savable
+type memFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *memInode
+
+ // mu guards the fields below.
+ mu sync.Mutex `state:"nosave"`
+ offset int64
+}
+
+// Init initializes memFD.
+func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error {
+ fd.LockFD.Init(&inode.locks)
+ if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return err
+ }
+ fd.inode = inode
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ case linux.SEEK_CUR:
+ offset += fd.offset
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.offset = offset
+ return offset, nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ m, err := getMMIncRef(fd.inode.task)
+ if err != nil {
+ return 0, nil
+ }
+ defer m.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.offset, opts)
+ fd.offset += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *memFD) Release(context.Context) {}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil)
func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &exeSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil)
func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &cwdSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
inode := &namespaceSymlink{task: task}
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
taskInode := &taskOwnedInode{Inode: inode, owner: task}
return taskInode
@@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
// Create a synthetic inode to represent the namespace.
fs := mnt.Filesystem().Impl().(*filesystem)
+ nsInode := &namespaceInode{}
+ nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444)
dentry := &kernfs.Dentry{}
- dentry.Init(&fs.Filesystem, &namespaceInode{})
+ dentry.Init(&fs.Filesystem, nsInode)
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
// Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1.
mnt.IncRef()
@@ -897,11 +1056,11 @@ type namespaceInode struct {
var _ kernfs.Inode = (*namespaceInode)(nil)
// Init initializes a namespace inode.
-func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 3425e8698..5a9ee111f 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode {
// TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
// network namespace.
contents = map[string]kernfs.Inode{
- "dev": fs.newInode(root, 0444, &netDevData{stack: stack}),
- "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}),
+ "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}),
+ "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, if the file contains a header the stub is just the header
// otherwise it is an empty file.
- "arp": fs.newInode(root, 0444, newStaticFile(arp)),
- "netlink": fs.newInode(root, 0444, newStaticFile(netlink)),
- "netstat": fs.newInode(root, 0444, &netStatData{}),
- "packet": fs.newInode(root, 0444, newStaticFile(packet)),
- "protocols": fs.newInode(root, 0444, newStaticFile(protocols)),
+ "arp": fs.newInode(task, root, 0444, newStaticFile(arp)),
+ "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)),
+ "netstat": fs.newInode(task, root, 0444, &netStatData{}),
+ "packet": fs.newInode(task, root, 0444, newStaticFile(packet)),
+ "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)),
// Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
// high res timer ticks per sec (ClockGetres returns 1ns resolution).
- "psched": fs.newInode(root, 0444, newStaticFile(psched)),
- "ptype": fs.newInode(root, 0444, newStaticFile(ptype)),
- "route": fs.newInode(root, 0444, &netRouteData{stack: stack}),
- "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}),
- "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}),
- "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}),
+ "psched": fs.newInode(task, root, 0444, newStaticFile(psched)),
+ "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)),
+ "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}),
+ "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}),
+ "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}),
}
if stack.SupportsIPv6() {
- contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack})
- contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile(""))
- contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k})
- contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6))
+ contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6))
}
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 3259c3732..b81ea14bf 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -62,19 +62,19 @@ type tasksInode struct {
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
- "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))),
- "filesystems": fs.newInode(root, 0444, &filesystemsData{}),
- "loadavg": fs.newInode(root, 0444, &loadavgData{}),
- "sys": fs.newSysDir(root, k),
- "meminfo": fs.newInode(root, 0444, &meminfoData{}),
- "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
- "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
- "stat": fs.newInode(root, 0444, &statData{}),
- "uptime": fs.newInode(root, 0444, &uptimeData{}),
- "version": fs.newInode(root, 0444, &versionData{}),
+ "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}),
+ "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}),
+ "sys": fs.newSysDir(ctx, root, k),
+ "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
+ "stat": fs.newInode(ctx, root, 0444, &statData{}),
+ "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
+ "version": fs.newInode(ctx, root, 0444, &versionData{}),
}
inode := &tasksInode{
@@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace
fs: fs,
cgroupControllers: cgroupControllers,
}
- inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
@@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
// If it failed to parse, check if it's one of the special handled files.
switch name {
case selfName:
- return i.newSelfSymlink(root), nil
+ return i.newSelfSymlink(ctx, root), nil
case threadSelfName:
- return i.newThreadSelfSymlink(root), nil
+ return i.newThreadSelfSymlink(ctx, root), nil
}
return nil, syserror.ENOENT
}
@@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 07c27cdd9..01b7a6678 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -43,9 +43,9 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &selfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
@@ -84,9 +84,9 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &threadSelfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 95420368d..7c7afdcfa 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -40,93 +40,93 @@ const (
)
// newSysDir returns the dentry corresponding to /proc/sys directory.
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
- return fs.newStaticDir(root, map[string]kernfs.Inode{
- "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{
- "hostname": fs.newInode(root, 0444, &hostnameData{}),
- "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)),
+func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
}),
- "vm": fs.newStaticDir(root, map[string]kernfs.Inode{
- "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}),
- "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")),
+ "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")),
}),
- "net": fs.newSysNetDir(root, k),
+ "net": fs.newSysNetDir(ctx, root, k),
})
}
// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
-func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
var contents map[string]kernfs.Inode
// TODO(gvisor.dev/issue/1833): Support for using the network stack in the
// network namespace of the calling process.
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]kernfs.Inode{
- "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{
- "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}),
+ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")),
- "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")),
- "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")),
- "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")),
- "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")),
+ "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")),
// tcp_allowed_congestion_control tell the user what they are able to
// do as an unprivledged process so we leave it empty.
- "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
- "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
+ "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
// Many of the following stub files are features netstack doesn't
// support. The unsupported features return "0" to indicate they are
// disabled.
- "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")),
- "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")),
- "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")),
- "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")),
- "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")),
}),
- "core": fs.newStaticDir(root, map[string]kernfs.Inode{
- "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")),
- "message_burst": fs.newInode(root, 0444, newStaticFile("10")),
- "message_cost": fs.newInode(root, 0444, newStaticFile("5")),
- "optmem_max": fs.newInode(root, 0444, newStaticFile("0")),
- "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
- "somaxconn": fs.newInode(root, 0444, newStaticFile("128")),
- "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
+ "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")),
+ "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")),
+ "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
}),
}
}
- return fs.newStaticDir(root, contents)
+ return fs.newStaticDir(ctx, root, contents)
}
// mmapMinAddrData implements vfs.DynamicBytesSource for
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 2582ababd..7ee6227a9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -77,6 +77,7 @@ var (
"gid_map": linux.DT_REG,
"io": linux.DT_REG,
"maps": linux.DT_REG,
+ "mem": linux.DT_REG,
"mountinfo": linux.DT_REG,
"mounts": linux.DT_REG,
"net": linux.DT_DIR,
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index cf91ea36c..fda1fa942 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// NewDentry constructs and returns a sockfs dentry.
//
// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
-func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
+func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry {
fs := mnt.Filesystem().Impl().(*filesystem)
// File mode matches net/socket.c:sock_alloc.
filemode := linux.FileMode(linux.S_IFSOCK | 0600)
i := &inode{}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
+ i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
d := &kernfs.Dentry{}
d.Init(&fs.Filesystem, i)
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 906cd52cb..09043b572 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "dir_refs.go",
package = "sys",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/coverage",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index 31a361029..b13f141a8 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -29,7 +29,7 @@ import (
func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
k := &kcovInode{}
- k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
+ k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
return k
}
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1ad679830..506a2a0f0 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -18,6 +18,7 @@ package sys
import (
"bytes"
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -29,9 +30,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "sysfs"
-const defaultSysDirMode = linux.FileMode(0755)
+const (
+ // Name is the default filesystem name.
+ Name = "sysfs"
+ defaultSysDirMode = linux.FileMode(0755)
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType implements vfs.FilesystemType.
//
@@ -62,31 +66,43 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
fs := &filesystem{
devMinor: devMinor,
}
+ fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
- root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "block": fs.newDir(creds, defaultSysDirMode, nil),
- "bus": fs.newDir(creds, defaultSysDirMode, nil),
- "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "power_supply": fs.newDir(creds, defaultSysDirMode, nil),
+ root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil),
}),
- "dev": fs.newDir(creds, defaultSysDirMode, nil),
- "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"cpu": cpuDir(ctx, fs, creds),
}),
}),
- "firmware": fs.newDir(creds, defaultSysDirMode, nil),
- "fs": fs.newDir(creds, defaultSysDirMode, nil),
+ "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"kernel": kernelDir(ctx, fs, creds),
- "module": fs.newDir(creds, defaultSysDirMode, nil),
- "power": fs.newDir(creds, defaultSysDirMode, nil),
+ "module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
})
var rootD kernfs.Dentry
- rootD.Init(&fs.Filesystem, root)
+ rootD.InitRoot(&fs.Filesystem, root)
return fs.VFSFilesystem(), rootD.VFSDentry(), nil
}
@@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
children := map[string]kernfs.Inode{
- "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
}
for i := uint(0); i < maxCPUCores; i++ {
- children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil)
+ children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil)
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
@@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker
var children map[string]kernfs.Inode
if coverage.KcovAvailable() {
children = map[string]kernfs.Inode{
- "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{
+ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{
"kcov": fs.newKcovFile(ctx, creds),
}),
}
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
// Release implements vfs.FilesystemImpl.Release.
@@ -140,9 +156,9 @@ type dir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
d := &dir{}
- d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
d.EnableLeakCheck()
d.IncLinks(d.OrderedChildren.Populate(contents))
@@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
+func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
c := &cpuFile{maxCores: maxCores}
- c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
+ c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
return c
}
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 5cd428d64..fe520b6fd 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -31,7 +31,7 @@ go_template_instance(
out = "inode_refs.go",
package = "tmpfs",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -48,6 +48,7 @@ go_library(
"inode_refs.go",
"named_pipe.go",
"regular_file.go",
+ "save_restore.go",
"socket_file.go",
"symlink.go",
"tmpfs.go",
@@ -60,6 +61,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index d772db9e9..57e7b57b0 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
@@ -32,7 +31,7 @@ type namedPipe struct {
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)}
file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index ce4e3eda7..98680fde9 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -42,7 +42,7 @@ type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
- memFile *pgalloc.MemoryFile
+ memFile *pgalloc.MemoryFile `state:"nosave"`
// memoryUsageKind is the memory accounting category under which pages backing
// this regularFile's contents are accounted.
@@ -92,7 +92,7 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
- memFile: fs.memFile,
+ memFile: fs.mfp.MemoryFile(),
memoryUsageKind: usage.Tmpfs,
seals: linux.F_SEAL_SEAL,
}
diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go
new file mode 100644
index 000000000..b27f75cc2
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+// afterLoad is called by stateify.
+func (rf *regularFile) afterLoad() {
+ rf.memFile = rf.inode.fs.mfp.MemoryFile()
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index e2a0aac69..4ce859d57 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -61,8 +61,9 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
- // memFile is used to allocate pages to for regular files.
- memFile *pgalloc.MemoryFile
+ // mfp is used to allocate memory that stores regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
// clock is a realtime clock used to set timestamps in file operations.
clock time.Clock
@@ -106,8 +107,8 @@ type FilesystemOpts struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
- if memFileProvider == nil {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
panic("MemoryFileProviderFromContext returned nil")
}
@@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
clock := time.RealtimeClockFromContext(ctx)
fs := filesystem{
- memFile: memFileProvider.MemoryFile(),
+ mfp: mfp,
clock: clock,
devMinor: devMinor,
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 0ca750281..e265be0ee 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "verity",
srcs = [
"filesystem.go",
+ "save_restore.go",
"verity.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -15,6 +16,7 @@ go_library(
"//pkg/fspath",
"//pkg/marshal/primitive",
"//pkg/merkletree",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
@@ -38,10 +40,12 @@ go_test(
"//pkg/context",
"//pkg/fspath",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 03da505e1..4e8d63d51 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's hash.
@@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Start: parent.lowerVD,
}, &vfs.StatOptions{})
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// Verify returns with success.
var buf bytes.Buffer
if _, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: &buf,
- File: &fdReader,
- Tree: &fdReader,
- Size: int64(parentSize),
- Name: parent.name,
- Mode: uint32(parentStat.Mode),
- UID: parentStat.UID,
- GID: parentStat.GID,
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
- ReadSize: int64(merkletree.DigestSize()),
+ ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
}); err != nil && err != io.EOF {
- return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
@@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
return err
@@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return err
@@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
size, err := strconv.Atoi(merkleSize)
if err != nil {
- return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
var buf bytes.Buffer
params := &merkletree.VerifyParams{
- Out: &buf,
- Tree: &fdReader,
- Size: int64(size),
- Name: d.name,
- Mode: uint32(stat.Mode),
- UID: stat.UID,
- GID: stat.GID,
- ReadOffset: 0,
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
+ ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
ReadSize: 0,
Expected: d.hash,
@@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
}
if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
- return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
}
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
+ d.size = uint32(size)
return nil
}
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
+ // If verity is enabled on child, we should check again whether
+ // the file and the corresponding Merkle tree are as expected,
+ // in order to catch deletion/renaming after the last time it's
+ // accessed.
+ if child.verityEnabled() {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ // Get the path to the child dentry. This is only used
+ // to provide path information in failure case.
+ path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD)
+ if err != nil {
+ return nil, err
+ }
+
+ childVD, err := parent.getLowerAt(ctx, vfsObj, name)
+ if err == syserror.ENOENT {
+ // The file was previously accessed. If the
+ // file does not exist now, it indicates an
+ // unexpected modification to the file system.
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer childVD.DecRef(ctx)
+
+ childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
+ // The Merkle tree file was previous accessed. If it
+ // does not exist now, it indicates an unexpected
+ // modification to the file system.
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ defer childMerkleVD.DecRef(ctx)
+ }
+
// If enabling verification on files/directories is not allowed
// during runtime, all cached children are already verified. If
// runtime enable is allowed and the parent directory is
@@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
- childFilename := fspath.Parse(name)
- childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: childFilename,
- }, &vfs.GetDentryOptions{})
-
+ childVD, childErr := parent.getLowerAt(ctx, vfsObj, name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childErr != nil && childErr != syserror.ENOENT {
@@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
defer childVD.DecRef(ctx)
}
- childMerkleFilename := merklePrefix + name
- childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
-
+ childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childMerkleErr != nil && childMerkleErr != syserror.ENOENT {
@@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// corresponding Merkle tree is found. This indicates an
// unexpected modification to the file system that
// removed/renamed the child.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
} else if childErr == nil && childMerkleErr == syserror.ENOENT {
// If in allowRuntimeEnable mode, and the Merkle tree file is
// not created yet, we create an empty Merkle tree file, so that
@@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: parent.lowerVD,
Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
+ Path: fspath.Parse(merklePrefix + name),
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT,
Mode: 0644,
@@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
return nil, err
}
childMerkleFD.DecRef(ctx)
- childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
+ childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
if err != nil {
return nil, err
}
@@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// If runtime enable is not allowed. This indicates an
// unexpected modification to the file system that
// removed/renamed the Merkle tree file.
- return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
}
} else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT {
// Both the child and the corresponding Merkle tree are missing.
@@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// TODO(b/167752508): Investigate possible ways to differentiate
// cases that both files are deleted from cases that they never
// exist in the file system.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
}
mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
@@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// missing, it indicates an unexpected modification to the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
}
@@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
})
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err != nil {
if err == syserror.ENOENT {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
return nil, err
}
diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go
new file mode 100644
index 000000000..46b064342
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 8dc9e26bc..d24c839bb 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -23,6 +23,7 @@ package verity
import (
"fmt"
+ "math"
"strconv"
"sync/atomic"
@@ -31,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -41,32 +43,62 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// Name is the default filesystem name.
-const Name = "verity"
+const (
+ // Name is the default filesystem name.
+ Name = "verity"
-// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle
-// tree file for "/foo" is "/.merkle.verity.foo".
-const merklePrefix = ".merkle.verity."
+ // merklePrefix is the prefix of the Merkle tree files. For example, the Merkle
+ // tree file for "/foo" is "/.merkle.verity.foo".
+ merklePrefix = ".merkle.verity."
-// merkleoffsetInParentXattr is the extended attribute name specifying the
-// offset of child hash in its parent's Merkle tree.
-const merkleOffsetInParentXattr = "user.merkle.offset"
+ // merkleOffsetInParentXattr is the extended attribute name specifying the
+ // offset of the child hash in its parent's Merkle tree.
+ merkleOffsetInParentXattr = "user.merkle.offset"
-// merkleSizeXattr is the extended attribute name specifying the size of data
-// hashed by the corresponding Merkle tree. For a file, it's the size of the
-// whole file. For a directory, it's the size of all its children's hashes.
-const merkleSizeXattr = "user.merkle.size"
+ // merkleSizeXattr is the extended attribute name specifying the size of data
+ // hashed by the corresponding Merkle tree. For a regular file, this is the
+ // file size. For a directory, this is the size of all its children's hashes.
+ merkleSizeXattr = "user.merkle.size"
-// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
-// extended attributes. The maximum value of a 32 bit integer is 10 digits.
-const sizeOfStringInt32 = 10
+ // sizeOfStringInt32 is the size for a 32 bit integer stored as string in
+ // extended attributes. The maximum value of a 32 bit integer has 10 digits.
+ sizeOfStringInt32 = 10
+)
-// noCrashOnVerificationFailure indicates whether the sandbox should panic
-// whenever verification fails. If true, an error is returned instead of
-// panicking. This should only be set for tests.
-// TOOD(b/165661693): Decide whether to panic or return error based on this
-// flag.
-var noCrashOnVerificationFailure bool
+var (
+ // noCrashOnVerificationFailure indicates whether the sandbox should panic
+ // whenever verification fails. If true, an error is returned instead of
+ // panicking. This should only be set for tests.
+ //
+ // TODO(b/165661693): Decide whether to panic or return error based on this
+ // flag.
+ noCrashOnVerificationFailure bool
+
+ // verityMu synchronizes concurrent operations that enable verity and perform
+ // verification checks.
+ verityMu sync.RWMutex
+)
+
+// HashAlgorithm is a type specifying the algorithm used to hash the file
+// content.
+type HashAlgorithm int
+
+// Currently supported hashing algorithms include SHA256 and SHA512.
+const (
+ SHA256 HashAlgorithm = iota
+ SHA512
+)
+
+func (alg HashAlgorithm) toLinuxHashAlg() int {
+ switch alg {
+ case SHA256:
+ return linux.FS_VERITY_HASH_ALG_SHA256
+ case SHA512:
+ return linux.FS_VERITY_HASH_ALG_SHA512
+ default:
+ return 0
+ }
+}
// FilesystemType implements vfs.FilesystemType.
//
@@ -97,6 +129,10 @@ type filesystem struct {
// stores the root hash of the whole file system in bytes.
rootDentry *dentry
+ // alg is the algorithms used to hash the files in the verity file
+ // system.
+ alg HashAlgorithm
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -125,6 +161,10 @@ type InternalFilesystemOptions struct {
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
+ // Alg is the algorithms used to hash the files in the verity file
+ // system.
+ Alg HashAlgorithm
+
// RootHash is the root hash of the overall verity file system.
RootHash []byte
@@ -153,10 +193,10 @@ func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
-func alertIntegrityViolation(err error, msg string) error {
+// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+func alertIntegrityViolation(msg string) error {
if noCrashOnVerificationFailure {
- return err
+ return syserror.EIO
}
panic(msg)
}
@@ -183,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs := &filesystem{
creds: creds.Fork(),
+ alg: iopts.Alg,
lowerMount: mnt,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
@@ -236,7 +277,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// the root Merkle file, or it's never generated.
fs.vfsfs.DecRef(ctx)
d.DecRef(ctx)
- return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file")
+ return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
}
d.lowerMerkleVD = lowerMerkleVD
@@ -289,11 +330,12 @@ type dentry struct {
// fs is the owning filesystem. fs is immutable.
fs *filesystem
- // mode, uid and gid are the file mode, owner, and group of the file in
- // the underlying file system.
+ // mode, uid, gid and size are the file mode, owner, group, and size of
+ // the file in the underlying file system.
mode uint32
uid uint32
gid uint32
+ size uint32
// parent is the dentry corresponding to this dentry's parent directory.
// name is this dentry's name in parent. If this dentry is a filesystem
@@ -331,22 +373,25 @@ func (fs *filesystem) newDentry() *dentry {
fs: fs,
}
d.vfsd.Init(d)
+ refsvfs2.Register(d)
return d
}
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -354,15 +399,27 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
+ } else if r < 0 {
panic("verity.dentry.DecRef() called without holding a reference")
}
}
+func (d *dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.checkDropLocked(ctx)
+ } else if r < 0 {
+ panic("verity.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
// checkDropLocked should be called after d's reference count becomes 0 or it
// becomes deleted.
func (d *dentry) checkDropLocked(ctx context.Context) {
@@ -393,23 +450,36 @@ func (d *dentry) destroyLocked(ctx context.Context) {
if d.lowerVD.Ok() {
d.lowerVD.DecRef(ctx)
}
-
if d.lowerMerkleVD.Ok() {
d.lowerMerkleVD.DecRef(ctx)
}
-
if d.parent != nil {
d.parent.dirMu.Lock()
if !d.vfsd.IsDead() {
delete(d.parent.children, d.name)
}
d.parent.dirMu.Unlock()
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkDropLocked(ctx)
- } else if refs < 0 {
- panic("verity.dentry.DecRef() called without holding a reference")
- }
+ d.parent.decRefLocked(ctx)
}
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "verity.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -448,6 +518,16 @@ func (d *dentry) verityEnabled() bool {
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
+// getLowerAt returns the dentry in the underlying file system, which is
+// represented by filename relative to d.
+func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) {
+ return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.GetDentryOptions{})
+}
+
func (d *dentry) readlink(ctx context.Context) (string, error) {
return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
Root: d.lowerVD,
@@ -489,6 +569,10 @@ type fileDescription struct {
// directory that contains the current file/directory. This is only used
// if allowRuntimeEnable is set to true.
parentMerkleWriter *vfs.FileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex `state:"nosave"`
+ off int64
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -524,6 +608,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ n := int64(0)
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ n = fd.off
+ case linux.SEEK_END:
+ n = int64(fd.d.size)
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset > math.MaxInt64-n {
+ return 0, syserror.EINVAL
+ }
+ offset += n
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
// generateMerkle generates a Merkle tree file for fd. If fd points to a file
// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
// of the generated Merkle tree and the data size is returned. If fd points to
@@ -546,6 +656,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
params := &merkletree.GenerateParams{
TreeReader: &merkleReader,
TreeWriter: &merkleWriter,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
}
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
@@ -611,7 +723,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// or directory other than the root, the parent Merkle tree file should
// have also been initialized.
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
- return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
+ return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
hash, dataSize, err := fd.generateMerkle(ctx)
@@ -657,6 +769,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// measureVerity returns the hash of fd, saved in verityDigest.
func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
var metadata linux.DigestMetadata
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
@@ -667,7 +782,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve
if fd.d.fs.allowRuntimeEnable {
return 0, syserror.ENODATA
}
- return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found")
+ return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found")
}
// The first part of VerityDigest is the metadata.
@@ -702,6 +817,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag
}
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
_, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -722,6 +840,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
}
}
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Implement Read with PRead by setting offset.
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
@@ -742,7 +870,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
return 0, err
@@ -752,7 +880,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
dataReader := vfs.FileReadWriteSeeker{
@@ -766,25 +894,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
}
n, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: dst.Writer(ctx),
- File: &dataReader,
- Tree: &merkleReader,
- Size: int64(size),
- Name: fd.d.name,
- Mode: fd.d.mode,
- UID: fd.d.uid,
- GID: fd.d.gid,
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
if err != nil {
- return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index e301d35f5..b2da9dd96 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -25,10 +25,12 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -41,11 +43,18 @@ const maxDataSize = 100000
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
-func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("testutil.Boot: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+
rand.Seed(time.Now().UnixNano())
vfsObj := &vfs.VirtualFilesystem{}
if err := vfsObj.Init(ctx); err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -61,22 +70,33 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v
InternalData: InternalFilesystemOptions{
RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
+ Alg: hashAlg,
AllowRuntimeEnable: true,
NoCrashOnVerificationFailure: true,
},
},
})
if err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err)
}
root := mntns.Root()
root.IncRef()
+
+ // Use lowerRoot in the task as we modify the lower file system
+ // directly in many tests.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot)
+ if err != nil {
+ t.Fatalf("testutil.CreateTask: %v", err)
+ }
+
t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
})
- return vfsObj, root, nil
+ return vfsObj, root, task, nil
}
// newFileFD creates a new file in the verity mount, and returns the FD. The FD
@@ -142,207 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
func TestOpen(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Ensure that the corresponding Merkle tree file is created.
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
}
+}
- // Ensure the root merkle tree file is created.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + rootMerkleFilename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity
+// file succeeds after enabling verity for it.
+func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
}
}
-// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file
-// succeeds after enabling verity for it.
+// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity
+// file succeeds after enabling verity for it.
func TestReadUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- buf := make([]byte, size)
- n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
- if err != nil && err != io.EOF {
- t.Fatalf("fd.PRead: %v", err)
- }
-
- if n != int64(size) {
- t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
}
}
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Ensure reopening the verity enabled file succeeds.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != nil {
- t.Errorf("reopen enabled file failed: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
}
}
-// TestModifiedFileFails ensures that read from a modified verity file fails.
-func TestModifiedFileFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+// TestPReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestPReadModifiedFileFails(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded, expected failure")
+ }
}
+}
- // Confirm that read from the modified file fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- t.Fatalf("fd.PRead succeeded with modified file")
+// TestReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestReadModifiedFileFails(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.Read succeeded, expected failure")
+ }
}
}
// TestModifiedMerkleFails ensures that read from a verity file fails if the
// corresponding Merkle tree file is modified.
func TestModifiedMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
-
- lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerMerkleVD,
- Start: lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the Merkle tree file.
- stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- merkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- // Confirm that read from a file with modified Merkle tree fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
- t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
}
}
@@ -350,142 +459,267 @@ func TestModifiedMerkleFails(t *testing.T) {
// verity enabled directory fails if the hashes related to the target file in
// the parent Merkle tree file is modified.
func TestModifiedParentMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
-
- parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: parentLowerMerkleVD,
- Start: parentLowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the parent Merkle tree file.
- // This parent directory contains only one child, so any random
- // modification in the parent Merkle tree should cause verification
- // failure when opening the child file.
- stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- parentMerkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- parentLowerMerkleFD.DecRef(ctx)
-
- // Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err == nil {
- t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
}
}
// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
// succeeds after enabling verity for it.
func TestUnmodifiedStatSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
- t.Errorf("fd.Stat: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
}
}
// TestModifiedStatFails checks that getting stat for a file with modified stat
// should fail.
func TestModifiedStatFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
}
+}
- lowerFD := fd.Impl().(*fileDescription).lowerFD
- // Change the stat of the underlying file, and check that stat fails.
- if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: uint32(linux.STATX_MODE),
- Mode: 0777,
+// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed
+// verity enabled file or the corresponding Merkle tree file fails with the
+// verify error.
+func TestOpenDeletedFileFails(t *testing.T) {
+ testCases := []struct {
+ // Tests removing files is remove is true. Otherwise tests
+ // renaming files.
+ remove bool
+ // The original file is removed/renamed if changeFile is true.
+ changeFile bool
+ // The Merkle tree file is removed/renamed if changeMerkleFile
+ // is true.
+ changeMerkleFile bool
+ }{
+ {
+ remove: true,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: true,
+ changeFile: false,
+ changeMerkleFile: true,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
},
- }); err != nil {
- t.Fatalf("lowerFD.SetStat: %v", err)
}
-
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
- t.Errorf("fd.Stat succeeded when it should fail")
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
+ if tc.remove {
+ if tc.changeFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ } else {
+ newFilename := "renamed-test-file"
+ if tc.changeFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("RenameAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ }
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != syserror.EIO {
+ t.Errorf("got OpenAt error: %v, expected EIO", err)
+ }
+ }
+ })
}
}
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
index 364a78306..db3b0d0a0 100644
--- a/pkg/sentry/hostfd/BUILD
+++ b/pkg/sentry/hostfd/BUILD
@@ -6,10 +6,12 @@ go_library(
name = "hostfd",
srcs = [
"hostfd.go",
+ "hostfd_linux.go",
"hostfd_unsafe.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/log",
"//pkg/safemem",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go
new file mode 100644
index 000000000..1cabc848f
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_linux.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+// maxIov is the maximum permitted size of a struct iovec array.
+const maxIov = 1024 // UIO_MAXIOV
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
index cd4dc67fb..694371b1c 100644
--- a/pkg/sentry/hostfd/hostfd_unsafe.go
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -20,6 +20,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
)
@@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6
}
} else {
iovs := safemem.IovecsFromBlockSeq(dsts)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
@@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint
}
} else {
iovs := safemem.IovecsFromBlockSeq(srcs)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index fbe6d6aa6..f31277d30 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -32,9 +32,13 @@ type Stack interface {
InterfaceAddrs() map[int32][]InterfaceAddr
// AddInterfaceAddr adds an address to the network interface identified by
- // index.
+ // idx.
AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+ // RemoveInterfaceAddr removes an address from the network interface
+ // identified by idx.
+ RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error
+
// SupportsIPv6 returns true if the stack supports IPv6 connectivity.
SupportsIPv6() bool
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 1779cc6f3..9ebeba8a3 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -15,6 +15,9 @@
package inet
import (
+ "bytes"
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
return nil
}
+// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr.
+func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ interfaceAddrs, ok := s.InterfaceAddrsMap[idx]
+ if !ok {
+ return fmt.Errorf("unknown idx: %d", idx)
+ }
+
+ var filteredAddrs []InterfaceAddr
+ for _, interfaceAddr := range interfaceAddrs {
+ if !bytes.Equal(interfaceAddr.Addr, addr.Addr) {
+ filteredAddrs = append(filteredAddrs, addr)
+ }
+ }
+ s.InterfaceAddrsMap[idx] = filteredAddrs
+
+ return nil
+}
+
// SupportsIPv6 implements Stack.SupportsIPv6.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c0de72eef..90dd4a047 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -79,7 +79,7 @@ go_template_instance(
out = "fd_table_refs.go",
package = "kernel",
prefix = "FDTable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FDTable",
},
@@ -90,7 +90,7 @@ go_template_instance(
out = "fs_context_refs.go",
package = "kernel",
prefix = "FSContext",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FSContext",
},
@@ -101,7 +101,7 @@ go_template_instance(
out = "ipc_namespace_refs.go",
package = "kernel",
prefix = "IPCNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "IPCNamespace",
},
@@ -112,7 +112,7 @@ go_template_instance(
out = "process_group_refs.go",
package = "kernel",
prefix = "ProcessGroup",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ProcessGroup",
},
@@ -123,7 +123,7 @@ go_template_instance(
out = "session_refs.go",
package = "kernel",
prefix = "Session",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Session",
},
@@ -229,7 +229,7 @@ go_library(
"//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/secio",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 1b9721534..0ddbe5ff6 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -19,7 +19,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs_vfs2"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -27,7 +27,7 @@ import (
// +stateify savable
type abstractEndpoint struct {
ep transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
name string
ns *AbstractSocketNamespace
}
@@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace {
// its backing socket.
type boundEndpoint struct {
transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
}
// Release implements transport.BoundEndpoint.Release.
@@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
//
// When the last reference managed by socket is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
@@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran
// Remove removes the specified socket at name from the abstract socket
// namespace, if it has not yet been replaced.
-func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) {
+func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) {
a.mu.Lock()
defer a.mu.Unlock()
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 0ec7344cd..7aba31587 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
- f.init() // Initialize table.
+ f.initNoLeakCheck() // Initialize table.
f.used = 0
for fd, d := range m {
if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
@@ -240,6 +240,10 @@ func (f *FDTable) String() string {
case fileVFS2 != nil:
vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
+ vd := fileVFS2.VirtualDentry()
+ if vd.Dentry() == nil {
+ panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2))
+ }
name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
if err != nil {
fmt.Fprintf(&buf, "<err: %v>\n", err)
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index da79e6627..3476551f3 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -31,14 +31,21 @@ type descriptorTable struct {
slice unsafe.Pointer `state:".(map[int32]*descriptor)"`
}
-// init initializes the table.
+// initNoLeakCheck initializes the table without enabling leak checking.
//
-// TODO(gvisor.dev/1486): Enable leak check for FDTable.
-func (f *FDTable) init() {
+// This is used when loading an FDTable after S/R, during which the ref count
+// object itself will enable leak checking if necessary.
+func (f *FDTable) initNoLeakCheck() {
var slice []unsafe.Pointer // Empty slice.
atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
}
+// init initializes the table with leak checking.
+func (f *FDTable) init() {
+ f.initNoLeakCheck()
+ f.EnableLeakCheck()
+}
+
// get gets a file entry.
//
// The boolean indicates whether this was in range.
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index d46d1e1c1..41fb2a784 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext {
f.root.IncRef()
}
- return &FSContext{
+ ctx := &FSContext{
cwd: f.cwd,
root: f.root,
cwdVFS2: f.cwdVFS2,
rootVFS2: f.rootVFS2,
umask: f.umask,
}
+ ctx.EnableLeakCheck()
+ return ctx
}
// WorkingDirectory returns the current working directory.
@@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
+ if f.cwd != nil {
+ f.cwd.IncRef()
+ }
return f.cwd
}
// WorkingDirectoryVFS2 returns the current working directory.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwdVFS2.IncRef()
+ if f.cwdVFS2.Ok() {
+ f.cwdVFS2.IncRef()
+ }
return f.cwdVFS2
}
@@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
// RootDirectoryVFS2 returns the current filesystem root.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.rootVFS2.IncRef()
+ if f.rootVFS2.Ok() {
+ f.rootVFS2.IncRef()
+ }
return f.rootVFS2
}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
index 3f34ee0db..b87e40dd1 100644
--- a/pkg/sentry/kernel/ipc_namespace.go
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry {
return i.shms
}
-// DecRef implements refs_vfs2.RefCounter.DecRef.
+// DecRef implements refsvfs2.RefCounter.DecRef.
func (i *IPCNamespace) DecRef(ctx context.Context) {
i.IPCNamespaceRefs.DecRef(func() {
i.shms.Release(ctx)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 0eb2bf7bd..9b2be44d4 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -430,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w wire.Writer) error {
+func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
saveStart := time.Now()
- ctx := k.SupervisorContext()
// Do not allow other Kernel methods to affect it while it's being saved.
k.extMu.Lock()
@@ -446,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
k.mf.StartEvictions()
k.mf.WaitForEvictions()
- // Flush write operations on open files so data reaches backing storage.
- // This must come after MemoryFile eviction since eviction may cause file
- // writes.
- if err := k.tasks.flushWritesToFiles(ctx); err != nil {
- return err
- }
+ if VFS2Enabled {
+ // Discard unsavable mappings, such as those for host file descriptors.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
- // Remove all epoll waiter objects from underlying wait queues.
- // NOTE: for programs to resume execution in future snapshot scenarios,
- // we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters(ctx)
+ // Prepare filesystems for saving. This must be done after
+ // invalidateUnsavableMappings(), since dropping memory mappings may
+ // affect filesystem state (e.g. page cache reference counts).
+ if err := k.vfs.PrepareSave(ctx); err != nil {
+ return err
+ }
+ } else {
+ // Flush cached file writes to backing storage. This must come after
+ // MemoryFile eviction since eviction may cause file writes.
+ if err := k.flushWritesToFiles(ctx); err != nil {
+ return err
+ }
- // Clear the dirent cache before saving because Dirents must be Loaded in a
- // particular order (parents before children), and Loading dirents from a cache
- // breaks that order.
- if err := k.flushMountSourceRefs(ctx); err != nil {
- return err
- }
+ // Remove all epoll waiter objects from underlying wait queues.
+ // NOTE: for programs to resume execution in future snapshot scenarios,
+ // we will need to re-establish these waiter objects after saving.
+ k.tasks.unregisterEpollWaiters(ctx)
- // Ensure that all inode and mount release operations have completed.
- fs.AsyncBarrier()
+ // Clear the dirent cache before saving because Dirents must be Loaded in a
+ // particular order (parents before children), and Loading dirents from a cache
+ // breaks that order.
+ if err := k.flushMountSourceRefs(ctx); err != nil {
+ return err
+ }
- // Once all fs work has completed (flushed references have all been released),
- // reset mount mappings. This allows individual mounts to save how inodes map
- // to filesystem resources. Without this, fs.Inodes cannot be restored.
- fs.SaveInodeMappings()
+ // Ensure that all inode and mount release operations have completed.
+ fs.AsyncBarrier()
- // Discard unsavable mappings, such as those for host file descriptors.
- // This must be done after waiting for "asynchronous fs work", which
- // includes async I/O that may touch application memory.
- if err := k.invalidateUnsavableMappings(ctx); err != nil {
- return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ // Once all fs work has completed (flushed references have all been released),
+ // reset mount mappings. This allows individual mounts to save how inodes map
+ // to filesystem resources. Without this, fs.Inodes cannot be restored.
+ fs.SaveInodeMappings()
+
+ // Discard unsavable mappings, such as those for host file descriptors.
+ // This must be done after waiting for "asynchronous fs work", which
+ // includes async I/O that may touch application memory.
+ //
+ // TODO(gvisor.dev/issue/1624): This rationale is believed to be
+ // obsolete since AIO callbacks are now waited-for by Kernel.Pause(),
+ // but this order is conservatively retained for VFS1.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
}
// Save the CPUID FeatureSet before the rest of the kernel so we can
@@ -486,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
+ if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- stats, err := state.Save(k.SupervisorContext(), w, k)
+ stats, err := state.Save(ctx, w, k)
if err != nil {
return err
}
@@ -502,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
+ if err := k.mf.SaveTo(ctx, w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -514,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
+//
+// Preconditions: !VFS2Enabled.
func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
- if VFS2Enabled {
- return nil // Not relevant.
- }
-
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -561,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi
return err
}
-func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return nil
- }
-
- return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
+// Preconditions: !VFS2Enabled.
+func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
+ return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
}
@@ -589,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
})
}
-// Preconditions: The kernel must be paused.
-func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
- invalidated := make(map[*mm.MemoryManager]struct{})
- k.tasks.mu.RLock()
- defer k.tasks.mu.RUnlock()
- for t := range k.tasks.Root.tids {
- // We can skip locking Task.mu here since the kernel is paused.
- if mm := t.tc.MemoryManager; mm != nil {
- if _, ok := invalidated[mm]; !ok {
- if err := mm.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- invalidated[mm] = struct{}{}
- }
- }
- // I really wish we just had a sync.Map of all MMs...
- if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
- if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
+// Preconditions: !VFS2Enabled.
func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return
- }
-
ts.mu.RLock()
defer ts.mu.RUnlock()
@@ -644,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
}
}
+// Preconditions: The kernel must be paused.
+func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
+ invalidated := make(map[*mm.MemoryManager]struct{})
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t := range k.tasks.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if mm := t.tc.MemoryManager; mm != nil {
+ if _, ok := invalidated[mm]; !ok {
+ if err := mm.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ invalidated[mm] = struct{}{}
+ }
+ }
+ // I really wish we just had a sync.Map of all MMs...
+ if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
+ if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -656,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
+ if _, err := state.Load(ctx, r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -671,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the kernel state.
kernelStart := time.Now()
- stats, err := state.Load(k.SupervisorContext(), r, k)
+ stats, err := state.Load(ctx, r, k)
if err != nil {
return err
}
@@ -684,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
+ if err := k.mf.LoadFrom(ctx, r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -696,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
net.Resume()
}
- // Ensure that all pending asynchronous work is complete:
- // - namedpipe opening
- // - inode file opening
- if err := fs.AsyncErrorBarrier(); err != nil {
- return err
+ if VFS2Enabled {
+ if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil {
+ return err
+ }
+ } else {
+ // Ensure that all pending asynchronous work is complete:
+ // - namedpipe opening
+ // - inode file opening
+ if err := fs.AsyncErrorBarrier(); err != nil {
+ return err
+ }
}
tcpip.AsyncLoading.Wait()
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index ce0db5583..d6fb0fdb8 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
type sleeper struct {
@@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag
d := fs.NewDirent(ctx, inode, "pipe")
file, err := n.GetFile(ctx, d, flags)
if err != nil {
- t.Fatalf("open with flags %+v failed: %v", flags, err)
+ t.Errorf("open with flags %+v failed: %v", flags, err)
+ return nil, err
}
if doneChan != nil {
doneChan <- struct{}{}
@@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.
}
func newNamedPipe(t *testing.T) *Pipe {
- return NewPipe(true, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(true, DefaultPipeSize)
}
func newAnonPipe(t *testing.T) *Pipe {
- return NewPipe(false, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(false, DefaultPipeSize)
}
// assertRecvBlocks ensures that a recv attempt on c blocks for at least
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 67beb0ad6..b989e14c7 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -26,18 +26,27 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// MinimumPipeSize is a hard limit of the minimum size of a pipe.
- MinimumPipeSize = 64 << 10
+ // It corresponds to fs/pipe.c:pipe_min_size.
+ MinimumPipeSize = usermem.PageSize
+
+ // MaximumPipeSize is a hard limit on the maximum size of a pipe.
+ // It corresponds to fs/pipe.c:pipe_max_size.
+ MaximumPipeSize = 1048576
// DefaultPipeSize is the system-wide default size of a pipe in bytes.
- DefaultPipeSize = MinimumPipeSize
+ // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS.
+ DefaultPipeSize = 16 * usermem.PageSize
- // MaximumPipeSize is a hard limit on the maximum size of a pipe.
- MaximumPipeSize = 8 << 20
+ // atomicIOBytes is the maximum number of bytes that the pipe will
+ // guarantee atomic reads or writes atomically.
+ // It corresponds to limits.h:PIPE_BUF.
+ atomicIOBytes = 4096
)
// Pipe is an encapsulation of a platform-independent pipe.
@@ -53,12 +62,6 @@ type Pipe struct {
// This value is immutable.
isNamed bool
- // atomicIOBytes is the maximum number of bytes that the pipe will
- // guarantee atomic reads or writes atomically.
- //
- // This value is immutable.
- atomicIOBytes int64
-
// The number of active readers for this pipe.
//
// Access atomically.
@@ -94,47 +97,34 @@ type Pipe struct {
// NewPipe initializes and returns a pipe.
//
-// N.B. The size and atomicIOBytes will be bounded.
-func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+// N.B. The size will be bounded.
+func NewPipe(isNamed bool, sizeBytes int64) *Pipe {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
var p Pipe
- initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&p, isNamed, sizeBytes)
return &p
}
-func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
pipe.isNamed = isNamed
pipe.max = sizeBytes
- pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
// representing the read and write ends of the pipe.
-func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
- p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes)
+func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(false /* isNamed */, sizeBytes)
// Build an fs.Dirent for the pipe which will be shared by both
// returned files.
@@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
wanted := ops.left()
avail := p.max - p.view.Size()
if wanted > avail {
- if wanted <= p.atomicIOBytes {
+ if wanted <= atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index fe97e9800..3dd739080 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -26,7 +26,7 @@ import (
func TestPipeRW(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) {
func TestPipeReadBlock(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) {
const capacity = MinimumPipeSize
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, capacity)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) {
const atomicIOBytes = 2
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, atomicIOBytes)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) {
}
}
if err != nil {
- t.Fatalf("Readv: got unexpected error %v", err)
+ t.Errorf("Readv: got unexpected error %v", err)
+ return
}
}
}()
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 1a152142b..7b23cbe86 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -33,6 +33,8 @@ import (
// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should
// not be copied.
+//
+// +stateify savable
type VFSPipe struct {
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
@@ -52,9 +54,9 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes)
return &vp
}
@@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l
// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
// other FileDescriptions for splice(2) and tee(2).
+//
+// +stateify savable
type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1145faf13..1abfe2201 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
// at the address specified by the data parameter, and the return value
// is the error flag." - ptrace(2)
word := t.Arch().Native(0)
- if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
+ if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
return err
}
_, err := word.CopyOut(t, data)
@@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
word := t.Arch().Native(uintptr(data))
- _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr)
+ _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr)
return err
case linux.PTRACE_GETREGSET:
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index c00fa1138..b99c0bffa 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -103,6 +103,7 @@ type waiter struct {
waiterEntry
// value represents how much resource the waiter needs to wake up.
+ // The value is either 0 or negative.
value int16
ch chan struct{}
}
@@ -283,6 +284,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File
return nil
}
+// GetStat extracts semid_ds information from the set.
+func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ ds := &linux.SemidDS{
+ SemPerm: linux.IPCPerm{
+ Key: uint32(s.key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
+ Mode: uint16(s.perms.LinuxMode()),
+ Seq: 0, // IPC sequence not supported.
+ },
+ SemOTime: s.opTime.TimeT(),
+ SemCTime: s.changeTime.TimeT(),
+ SemNSems: uint64(s.Size()),
+ }
+ return ds, nil
+}
+
// SetVal overrides a semaphore value, waking up waiters as needed.
func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error {
if val < 0 || val > valueMax {
@@ -320,7 +348,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
}
for _, val := range vals {
- if val < 0 || val > valueMax {
+ if val > valueMax {
return syserror.ERANGE
}
}
@@ -396,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
return sem.pid, nil
}
+func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // The calling process must have read permission on the semaphore set.
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ var cnt uint16
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ if pred(w) {
+ cnt++
+ }
+ }
+ return cnt, nil
+}
+
+// CountZeroWaiters returns number of waiters waiting for the sem's value to increase.
+func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value == 0
+ })
+}
+
+// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero.
+func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value < 0
+ })
+}
+
// ExecuteOps attempts to execute a list of operations to the set. It only
// succeeds when all operations can be applied. No changes are made if it fails.
//
@@ -548,11 +612,18 @@ func (s *Set) destroy() {
}
}
+func abs(val int16) int16 {
+ if val < 0 {
+ return -val
+ }
+ return val
+}
+
// wakeWaiters goes over all waiters and checks which of them can be notified.
func (s *sem) wakeWaiters() {
// Note that this will release all waiters waiting for 0 too.
for w := s.waiters.Front(); w != nil; {
- if s.value < w.value {
+ if s.value < abs(w.value) {
// Still blocked, skip it.
w = w.Next()
continue
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index df5c8421b..5bddb0a36 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session {
//
// If this group isn't visible in this namespace, zero will be returned. It is
// the callers responsibility to check that before using this function.
-func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sids[s]
+func (ns *PIDNamespace) IDOfSession(s *Session) SessionID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sids[s]
}
// SessionWithID returns the Session with the given ID in the PID namespace ns,
// or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the session.
-func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sessions[id]
+func (ns *PIDNamespace) SessionWithID(id SessionID) *Session {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sessions[id]
}
// ProcessGroup returns the ThreadGroup's ProcessGroup.
@@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup {
// IDOfProcessGroup returns the process group assigned to pg in PID namespace ns.
//
// The same constraints apply as IDOfSession.
-func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.pgids[pg]
+func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.pgids[pg]
}
// ProcessGroupWithID returns the ProcessGroup with the given ID in the PID
// namespace ns, or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the process group.
-func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.processGroups[id]
+func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.processGroups[id]
}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index f8a382fd8..80a592c8f 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "shm_refs.go",
package = "shm",
prefix = "Shm",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Shm",
},
@@ -27,7 +27,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 682080c14..527344162 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
}
if opts.ChildSetTID {
ctid := nt.ThreadID()
- ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
+ ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
}
ntid := t.tg.pidns.IDOfTask(nt)
if opts.ParentSetTID {
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index ce134bf54..94dabbcd8 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,7 +18,8 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp
}, nil
}
-// copyContext implements marshal.CopyContext. It wraps a task to allow copying
-// memory to and from the task memory with custom usermem.IOOpts.
-type copyContext struct {
- *Task
+type taskCopyContext struct {
+ ctx context.Context
+ t *Task
opts usermem.IOOpts
}
-// AsCopyContext wraps the task and returns it as CopyContext.
-func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext {
- return &copyContext{t, opts}
+// CopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts.
+func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext {
+ return &taskCopyContext{
+ ctx: ctx,
+ t: t,
+ opts: opts,
+ }
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte {
+ if ctxTask, ok := cc.ctx.(*Task); ok {
+ return ctxTask.CopyScratchBuffer(size)
+ }
+ return make([]byte, size)
+}
+
+func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) {
+ cc.t.mu.Lock()
+ tmm := cc.t.MemoryManager()
+ cc.t.mu.Unlock()
+ if !tmm.IncUsers() {
+ return nil, syserror.EFAULT
+ }
+ return tmm, nil
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyIn(cc.ctx, addr, dst, cc.opts)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyOut(cc.ctx, addr, src, cc.opts)
+}
+
+type ownTaskCopyContext struct {
+ t *Task
+ opts usermem.IOOpts
+}
+
+// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts. The returned CopyContext may only be used by t's task
+// goroutine.
+//
+// Since t already implements marshal.CopyContext, this is only needed to
+// override the usermem.IOOpts used for the copy.
+func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext {
+ return &ownTaskCopyContext{
+ t: t,
+ opts: opts,
+ }
}
-// CopyInString copies a string in from the task's memory.
-func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
- return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts)
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte {
+ return cc.t.CopyScratchBuffer(size)
}
-// CopyInBytes copies task memory into dst from an IO context.
-func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
- return t.MemoryManager().CopyIn(t, addr, dst, t.opts)
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts)
}
-// CopyOutBytes copies src into task memoryfrom an IO context.
-func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
- return t.MemoryManager().CopyOut(t, addr, src, t.opts)
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts)
}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index 9bc452e67..9e5c2d26f 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error {
}
if old != v.seq {
- return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq)
+ return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq)
}
v.seq = next
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b4a47ccca..6dbeccfe2 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -78,7 +78,7 @@ go_template_instance(
out = "aio_mappable_refs.go",
package = "mm",
prefix = "aioMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "aioMappable",
},
@@ -89,7 +89,7 @@ go_template_instance(
out = "special_mappable_refs.go",
package = "mm",
prefix = "SpecialMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SpecialMappable",
},
@@ -127,6 +127,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safecopy",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 0a54dd30d..acad4c793 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) {
c.runData.requestInterruptWindow = 0
}
+// bluepillSigBus is reponsible for injecting NMI to trigger sigbus.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_NMI, 0); errno != 0 {
+ throw("NMI injection failed")
+ }
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 58f3d6fdd..965ad66b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -27,15 +27,20 @@ var (
// The action for bluepillSignal is changed by sigaction().
bluepillSignal = syscall.SIGILL
- // vcpuSErr is the event of system error.
- vcpuSErr = kvmVcpuEvents{
+ // vcpuSErrBounce is the event of system error for bouncing KVM.
+ vcpuSErrBounce = kvmVcpuEvents{
exception: exception{
sErrPending: 1,
- sErrHasEsr: 0,
- pad: [6]uint8{0, 0, 0, 0, 0, 0},
- sErrEsr: 1,
},
- rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
+
+ // vcpuSErrNMI is the event of system error to trigger sigbus.
+ vcpuSErrNMI = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 1,
+ sErrEsr: _ESR_ELx_SERR_NMI,
+ },
}
)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index b35c930e2..9433d4da5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int {
//
//go:nosplit
func bluepillStopGuest(c *vCPU) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillSigBus is reponsible for injecting sError to trigger sigbus.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
throw("sErr injection failed")
}
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index eb05950cd..75085ac6a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall( // escapes: no.
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_NMI, 0); errno != 0 {
- throw("NMI injection failed")
- }
+ bluepillSigBus(c)
continue // Rerun vCPU.
default:
throw("run failed")
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index dd45ad10b..5979aef97 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
// NewAddressSpace returns a new pagetable root.
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
- pageTables := pagetables.New(newAllocator())
- k.machine.mapUpperHalf(pageTables)
+ pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 84df0f878..b060d9544 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -38,6 +38,8 @@ const (
_KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
_KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
_KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+ _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a
+ _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00
)
// Arm64: Architectural Feature Access Control Register EL1.
@@ -149,6 +151,9 @@ const (
_ESR_SEGV_PEMERR_L1 = 0xd
_ESR_SEGV_PEMERR_L2 = 0xe
_ESR_SEGV_PEMERR_L3 = 0xf
+
+ // Custom ISS field definitions for system error.
+ _ESR_ELx_SERR_NMI = 0x1
)
// Arm64: MMIO base address used to dispatch hypercalls.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 61ed24d01..e2fffc99b 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,6 +41,9 @@ type machine struct {
// slots are currently being updated, and the caller should retry.
nextSlot uint32
+ // upperSharedPageTables tracks the read-only shared upper of all the pagetables.
+ upperSharedPageTables *pagetables.PageTables
+
// kernel is the set of global structures.
kernel ring0.Kernel
@@ -198,9 +202,7 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- }, m.maxVCPUs)
+ m.kernel.Init(m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -212,6 +214,13 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of slots is %d.", m.maxSlots)
m.usedSlots = make([]uintptr, m.maxSlots)
+ // Create the upper shared pagetables and kernel(sentry) pagetables.
+ m.upperSharedPageTables = pagetables.New(newAllocator())
+ m.mapUpperHalf(m.upperSharedPageTables)
+ m.upperSharedPageTables.Allocator.(*allocator).base.Drain()
+ m.upperSharedPageTables.MarkReadOnlyShared()
+ m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -225,7 +234,6 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -625,3 +633,35 @@ func (c *vCPU) BounceToKernel() {
func (c *vCPU) BounceToHost() {
c.bounce(true)
}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c67127d95..8e03c310d 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error {
}
}
-// setSystemTimeLegacy calibrates and sets an approximate system time.
-func (c *vCPU) setSystemTimeLegacy() error {
- const minIterations = 10
- minimum := uint64(0)
- for iter := 0; ; iter++ {
- // Try to set the TSC to an estimate of where it will be
- // on the host during a "fast" system call iteration.
- start := uint64(ktime.Rdtsc())
- if err := c.setTSC(start + (minimum / 2)); err != nil {
- return err
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result.
- end := uint64(ktime.Rdtsc())
- if end < start {
- continue // Totally bogus: unstable TSC?
- }
- current := end - start
- if current < minimum || iter == 0 {
- minimum = current // Set our new minimum.
- }
- // Is this past minIterations and within ~10% of minimum?
- upperThreshold := (((minimum << 3) + minimum) >> 3)
- if iter >= minIterations && current <= upperThreshold {
- return nil
- }
- }
-}
-
// nonCanonical generates a canonical address return.
//
//go:nosplit
@@ -464,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
-var execRegions = func() (regions []region) {
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ // Map all the executible regions so that all the entry functions
+ // are mapped in the upper half.
applyVirtualRegions(func(vr virtualRegion) {
if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
return
}
+
if vr.accessType.Execute {
- regions = append(regions, vr.region)
+ r := vr.region
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
}
})
- return
-}()
-
-func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
- for _, r := range execRegions {
- physical, length, ok := translateToPhysical(r.virtual)
- if !ok || length < r.length {
- panic("impossilbe translation")
- }
- pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|r.virtual),
- r.length,
- pagetables.MapOpts{AccessType: usermem.Execute},
- physical)
- }
for start, end := range m.kernel.EntryRegions() {
regionLen := end - start
physical, length, ok := translateToPhysical(start)
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index a163f956d..fd92c3873 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error {
}
c.floatingPointState = arch.NewFloatingPointData()
+
+ return c.setSystemTime()
+}
+
+// setTSC sets the counter Virtual Offset.
+func (c *vCPU) setTSC(value uint64) error {
+ var (
+ reg kvmOneReg
+ data uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ reg.id = _KVM_ARM64_REGS_TIMER_CNT
+ data = uint64(value)
+
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
return nil
}
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ return c.setSystemTimeLegacy()
+}
+
//go:nosplit
func (c *vCPU) loadSegments(tid uint64) {
// TODO(gvisor.dev/issue/1238): TLS is not supported.
@@ -197,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
} else if !ring0.IsCanonical(regs.Sp) {
- return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info)
}
// Assign PCIDs.
@@ -233,10 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.PageFault:
return c.fault(int32(syscall.SIGSEGV), info)
+ case ring0.El0ErrNMI:
+ return c.fault(int32(syscall.SIGBUS), info)
case ring0.Vector(bounce): // ring0.VirtualizationException
return usermem.NoAccess, platform.ErrContextInterrupt
- case ring0.El0Sync_undef,
- ring0.El1Sync_undef:
+ case ring0.El0SyncUndef:
+ return c.fault(int32(syscall.SIGILL), info)
+ case ring0.El1SyncUndef:
*info = arch.SignalInfo{
Signo: int32(syscall.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 87a573cc4..327d48465 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -58,46 +58,55 @@ type Vector uintptr
// Exception vectors.
const (
- El1SyncInvalid = iota
- El1IrqInvalid
- El1FiqInvalid
- El1ErrorInvalid
+ El1InvSync = iota
+ El1InvIrq
+ El1InvFiq
+ El1InvError
+
El1Sync
El1Irq
El1Fiq
- El1Error
+ El1Err
+
El0Sync
El0Irq
El0Fiq
- El0Error
- El0Sync_invalid
- El0Irq_invalid
- El0Fiq_invalid
- El0Error_invalid
- El1Sync_da
- El1Sync_ia
- El1Sync_sp_pc
- El1Sync_undef
- El1Sync_dbg
- El1Sync_inv
- El0Sync_svc
- El0Sync_da
- El0Sync_ia
- El0Sync_fpsimd_acc
- El0Sync_sve_acc
- El0Sync_sys
- El0Sync_sp_pc
- El0Sync_undef
- El0Sync_dbg
- El0Sync_inv
+ El0Err
+
+ El0InvSync
+ El0InvIrq
+ El0InvFiq
+ El0InvErr
+
+ El1SyncDa
+ El1SyncIa
+ El1SyncSpPc
+ El1SyncUndef
+ El1SyncDbg
+ El1SyncInv
+
+ El0SyncSVC
+ El0SyncDa
+ El0SyncIa
+ El0SyncFpsimdAcc
+ El0SyncSveAcc
+ El0SyncSys
+ El0SyncSpPc
+ El0SyncUndef
+ El0SyncDbg
+ El0SyncInv
+
+ El0ErrNMI
+ El0ErrBounce
+
_NR_INTERRUPTS
)
// System call vectors.
const (
- Syscall Vector = El0Sync_svc
- PageFault Vector = El0Sync_da
- VirtualizationException Vector = El0Error
+ Syscall Vector = El0SyncSVC
+ PageFault Vector = El0SyncDa
+ VirtualizationException Vector = El0ErrBounce
)
// VirtualAddressBits returns the number bits available for virtual addresses.
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index e6daf24df..f9765771e 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -23,6 +23,9 @@ import (
//
// This contains global state, shared by multiple CPUs.
type Kernel struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+
KernelArchState
}
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 00899273e..7a2275558 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -66,17 +66,9 @@ var (
KernelDataSegment SegmentDescriptor
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
-
- // cpuEntries is array of kernelEntry for all cpus
+ // cpuEntries is array of kernelEntry for all cpus.
cpuEntries []kernelEntry
// globalIDT is our set of interrupt gates.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index 508236e46..a014dcbc0 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -32,15 +32,8 @@ var (
KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
}
// CPUArchState contains CPU-specific arch state.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 2370a9276..f489ad352 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -288,6 +288,10 @@
#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+/* ISS field definitions for system error */
+#define ESR_ELx_SERR_MASK (0x1)
+#define ESR_ELx_SERR_NMI (0x1)
+
// LOAD_KERNEL_ADDRESS loads a kernel address.
#define LOAD_KERNEL_ADDRESS(from, to) \
MOVD from, to; \
@@ -366,6 +370,19 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// EXCEPTION_WITH_ERROR is a common exception handler function.
+#define EXCEPTION_WITH_ERROR(user, vector) \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ WORD $0xd538601a; \ //MRS FAR_EL1, R26
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
+ MOVD $user, R3; \
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
+ MOVD $vector, R3; \
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
+ MRS ESR_EL1, R3; \
+ MOVD R3, CPU_ERROR_CODE(RSV_REG); \
+ B ·kernelExitToEl1(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -503,6 +520,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
MSR R1, ELR_EL1
+ // restore sentry's tls.
+ MOVD CPU_REGISTERS+PTRACE_TLS(RSV_REG), R1
+ MSR R1, TPIDR_EL0
+
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
@@ -659,21 +680,7 @@ el0_svc:
el0_da:
el0_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $1, R3
- MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- MRS ESR_EL1, R3
- MOVD R3, CPU_ERROR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, PageFault)
el0_fpsimd_acc:
B ·Shutdown(SB)
@@ -688,10 +695,7 @@ el0_sp_pc:
B ·Shutdown(SB)
el0_undef:
- MOVD $El0Sync_undef, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, El0SyncUndef)
el0_dbg:
B ·Shutdown(SB)
@@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0
TEXT ·El0_error(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ AND $ESR_ELx_SERR_MASK, R25, R24
+ CMP $ESR_ELx_SERR_NMI, R24
+ BEQ el0_nmi
+ B el0_bounce
+el0_nmi:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $El0ErrNMI, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ MRS ESR_EL1, R3
+ MOVD R3, CPU_ERROR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_bounce:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
WORD $0xd538601a //MRS FAR_EL1, R26
@@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0
MOVD $VirtualizationException, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·HaltAndResume(SB)
+ B ·kernelExitToEl1(SB)
TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 264be23d3..292f9d0cc 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -16,11 +16,9 @@ package ring0
// Init initializes a new kernel.
//
-// N.B. that constraints on KernelOpts must be satisfied.
-//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
- k.init(opts, maxCPUs)
+func (k *Kernel) Init(maxCPUs int) {
+ k.init(maxCPUs)
}
// Halt halts execution.
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 3a9dff4cc..b55dc29b3 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -24,10 +24,7 @@ import (
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
-
+func (k *Kernel) init(maxCPUs int) {
entrySize := reflect.TypeOf(kernelEntry{}).Size()
var (
entries []kernelEntry
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index b294ccc7c..6cbbf001f 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,9 +25,7 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
+func (k *Kernel) init(maxCPUs int) {
}
// init initializes architecture-specific state.
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 45eba960d..53bc3353c 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -47,43 +47,36 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
- fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
- fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
- fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
- fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
- fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+ fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err)
fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
- fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+ fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err)
- fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
- fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
- fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
- fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+ fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa)
+ fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa)
+ fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc)
+ fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef)
+ fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg)
+ fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv)
- fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
- fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
- fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
- fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
- fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
- fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+ fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC)
+ fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa)
+ fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
+ fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
+ fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
+ fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
+ fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
+ fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg)
+ fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv)
- fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
- fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
- fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
- fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
- fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
- fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
- fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
- fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
- fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
- fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+ fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI)
fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 7f18ac296..bc16a1622 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -30,6 +30,10 @@ type PageTables struct {
Allocator Allocator
// root is the pagetable root.
+ //
+ // For same archs such as amd64, the upper of the PTEs is cloned
+ // from and owned by upperSharedPageTables which are shared among
+ // many PageTables if upperSharedPageTables is not nil.
root *PTEs
// rootPhysical is the cached physical address of the root.
@@ -39,15 +43,52 @@ type PageTables struct {
// archPageTables includes architecture-specific features.
archPageTables
+
+ // upperSharedPageTables represents a read-only shared upper
+ // of the Pagetable. When it is not nil, the upper is not
+ // allowed to be modified.
+ upperSharedPageTables *PageTables
+
+ // upperStart is the start address of the upper portion that
+ // are shared from upperSharedPageTables
+ upperStart uintptr
+
+ // readOnlyShared indicates the Pagetables are read-only and
+ // own the ranges that are shared with other Pagetables.
+ readOnlyShared bool
}
-// New returns new PageTables.
-func New(a Allocator) *PageTables {
+// NewWithUpper returns new PageTables.
+//
+// upperSharedPageTables are used for mapping the upper of addresses,
+// starting at upperStart. These pageTables should not be touched (as
+// invalidations may be incorrect) after they are passed as an
+// upperSharedPageTables. Only when all dependent PageTables are gone
+// may they be used. The intenteded use case is for kernel page tables,
+// which are static and fixed.
+//
+// Precondition: upperStart must be between canonical ranges.
+// Precondition: upperStart must be pgdSize aligned.
+// precondition: upperSharedPageTables must be marked read-only shared.
+func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables {
p := new(PageTables)
p.Init(a)
+ if upperSharedPageTables != nil {
+ if !upperSharedPageTables.readOnlyShared {
+ panic("Only read-only shared pagetables can be used as upper")
+ }
+ p.upperSharedPageTables = upperSharedPageTables
+ p.upperStart = upperStart
+ p.cloneUpperShared()
+ }
return p
}
+// New returns new PageTables.
+func New(a Allocator) *PageTables {
+ return NewWithUpper(a, nil, 0)
+}
+
// mapVisitor is used for map.
type mapVisitor struct {
target uintptr // Input.
@@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // ignore change to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
if !opts.AccessType.Any() {
return p.Unmap(addr, length)
}
@@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// True is returned iff there was a previous mapping in the range.
//
-// Precondition: addr & length must be page-aligned.
+// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
// +checkescape:hard,stack
//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // ignore change to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
w := unmapWalker{
pageTables: p,
visitor: unmapVisitor{
@@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts)
w.iterateRange(uintptr(addr), uintptr(addr)+1)
return w.visitor.physical + offset, w.visitor.opts
}
+
+// MarkReadOnlyShared marks the pagetables read-only and can be shared.
+//
+// It is usually used on the pagetables that are used as the upper
+func (p *PageTables) MarkReadOnlyShared() {
+ p.readOnlyShared = true
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 520161755..a4e416af7 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -24,14 +24,6 @@ import (
// archPageTables is architecture-specific data.
type archPageTables struct {
- // root is the pagetable root for kernel space.
- root *PTEs
-
- // rootPhysical is the cached physical address of the root.
- //
- // This is saved only to prevent constant translation.
- rootPhysical uintptr
-
asid uint16
}
@@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
//
//go:nosplit
func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
- return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+ return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}
// Bits in page table entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index 0c153cf8c..e7ab887e5 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) {
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
}
+func pgdIndex(upperStart uintptr) uintptr {
+ if upperStart&(pgdSize-1) != 0 {
+ panic("upperStart should be pgd size aligned")
+ }
+ if upperStart >= upperBottom {
+ return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize
+ }
+ if upperStart < lowerTop {
+ return upperStart / pgdSize
+ }
+ panic("upperStart should be in canonical range")
+}
+
+// cloneUpperShared clone the upper from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ start := pgdIndex(p.upperStart)
+ copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage])
+}
+
// PTEs is a collection of entries.
type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 1a49f12a2..5392bf27a 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -36,7 +36,7 @@ const (
pudSize = 1 << pudShift
pgdSize = 1 << pgdShift
- ttbrASIDOffset = 55
+ ttbrASIDOffset = 48
ttbrASIDMask = 0xff
entriesPerPage = 512
@@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) {
p.Allocator = allocator
p.root = p.Allocator.NewPTEs()
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
- p.archPageTables.root = p.Allocator.NewPTEs()
- p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// cloneUpperShared clone the upper from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ if p.upperStart != upperBottom {
+ panic("upperStart should be the same as upperBottom")
+ }
+
+ // nothing to do for arm.
}
// PTEs is a collection of entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
index c261d393a..157c9a7cc 100644
--- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr {
func (w *Walker) iterateRangeCanonical(start, end uintptr) {
pgdEntryIndex := w.pageTables.root
if start >= upperBottom {
- pgdEntryIndex = w.pageTables.archPageTables.root
+ pgdEntryIndex = w.pageTables.upperSharedPageTables.root
}
for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index d9621968c..37d02948f 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -24,6 +24,8 @@ import (
)
// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+//
+// +stateify savable
type SCMRightsVFS2 interface {
transport.RightsControlMessage
@@ -34,9 +36,11 @@ type SCMRightsVFS2 interface {
Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
}
-// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
-// maintained for each vfs.FileDescription and is release either when an FD is created or
-// when the Release method is called.
+// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference
+// is maintained for each vfs.FileDescription and is release either when an FD
+// is created or when the Release method is called.
+//
+// +stateify savable
type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 163af329b..9a2cac40b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// +stateify savable
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{})
func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &socketVFS2{
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index faa61160e..7e7857ac3 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return syserror.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+func (s *Stack) SetTCPSACKEnabled(bool) error {
return syserror.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return syserror.EACCES
}
@@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
if rawLine == "" {
- return fmt.Errorf("Failed to get raw line")
+ return fmt.Errorf("failed to get raw line")
}
parts := strings.SplitN(rawLine, ":", 2)
if len(parts) != 2 {
- return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ return fmt.Errorf("failed to get prefix from: %q", rawLine)
}
sliceStat = toSlice(stat)
fields := strings.Fields(strings.TrimSpace(parts[1]))
if len(fields) != len(sliceStat) {
- return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ return fmt.Errorf("failed to parse fields: %q", rawLine)
}
if _, ok := stat.(*inet.StatSNMPTCP); ok {
snmpTCP = true
@@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
}
if err != nil {
- return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err)
}
}
@@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 549787955..e0976fed0 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
// marshalTarget and unmarshalTarget can be used.
type targetMaker interface {
// id uniquely identifies the target.
- id() stack.TargetID
+ id() targetID
- // marshal converts from a stack.Target to an ABI struct.
- marshal(target stack.Target) []byte
+ // marshal converts from a target to an ABI struct.
+ marshal(target target) []byte
- // unmarshal converts from the ABI matcher struct to a stack.Target.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+ // unmarshal converts from the ABI matcher struct to a target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error)
}
-// targetMakers maps the TargetID of supported targets to the targetMaker that
+// A targetID uniquely identifies a target.
+type targetID struct {
+ // name is the target name as stored in the xt_entry_target struct.
+ name string
+
+ // networkProtocol is the protocol to which the target applies.
+ networkProtocol tcpip.NetworkProtocolNumber
+
+ // revision is the version of the target.
+ revision uint8
+}
+
+// target extends a stack.Target, allowing it to be used with the extension
+// system. The sentry only uses targets, never stack.Targets directly.
+type target interface {
+ stack.Target
+ id() targetID
+}
+
+// targetMakers maps the targetID of supported targets to the targetMaker that
// marshals and unmarshals it. It is immutable after package initialization.
-var targetMakers = map[stack.TargetID]targetMaker{}
+var targetMakers = map[targetID]targetMaker{}
func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
- tid := stack.TargetID{
- Name: name,
- NetworkProtocol: netProto,
- Revision: rev,
+ tid := targetID{
+ name: name,
+ networkProtocol: netProto,
+ revision: rev,
}
if _, ok := targetMakers[tid]; !ok {
return 0, false
@@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8
// Return the highest supported revision unless rev is higher.
for _, other := range targetMakers {
otherID := other.id()
- if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
- rev = uint8(otherID.Revision)
+ if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev {
+ rev = uint8(otherID.revision)
}
}
return rev, true
@@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) {
targetMakers[tm.id()] = tm
}
-func marshalTarget(target stack.Target) []byte {
- targetMaker, ok := targetMakers[target.ID()]
+func marshalTarget(tgt stack.Target) []byte {
+ // The sentry only uses targets, never stack.Targets directly.
+ target := tgt.(target)
+ targetMaker, ok := targetMakers[target.id()]
if !ok {
- panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id()))
}
return targetMaker.marshal(target)
}
-func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
- tid := stack.TargetID{
- Name: target.Name.String(),
- NetworkProtocol: filter.NetworkProtocol(),
- Revision: target.Revision,
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) {
+ tid := targetID{
+ name: target.Name.String(),
+ networkProtocol: filter.NetworkProtocol(),
+ revision: target.Revision,
}
targetMaker, ok := targetMakers[tid]
if !ok {
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index b560fae0d..70c561cce 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct.
- entries, info := getEntries4(table, tablename)
+ entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 4253f7bf4..5dbb604f0 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct, which is the same in IPv4 and IPv6.
- entries, info := getEntries6(table, tablename)
+ entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 904a12e38..b283d7229 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) {
}
}
+// Table names.
+const (
+ natTable = "nat"
+ mangleTable = "mangle"
+ filterTable = "filter"
+)
+
+// nameToID is immutable.
+var nameToID = map[string]stack.TableID{
+ natTable: stack.NATID,
+ mangleTable: stack.MangleID,
+ filterTable: stack.FilterID,
+}
+
+// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
+// compatibility with netfilter extensions.
+func DefaultLinuxTables() *stack.IPTables {
+ tables := stack.DefaultTables()
+ tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
+ switch val := oldTarget.(type) {
+ case *stack.AcceptTarget:
+ return &acceptTarget{AcceptTarget: *val}
+ case *stack.DropTarget:
+ return &dropTarget{DropTarget: *val}
+ case *stack.ErrorTarget:
+ return &errorTarget{ErrorTarget: *val}
+ case *stack.UserChainTarget:
+ return &userChainTarget{UserChainTarget: *val}
+ case *stack.ReturnTarget:
+ return &returnTarget{ReturnTarget: *val}
+ case *stack.RedirectTarget:
+ return &redirectTarget{RedirectTarget: *val}
+ default:
+ panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val))
+ }
+ })
+ return tables
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
@@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.FilterTable:
+ case filterTable:
table = stack.EmptyFilterTable()
- case stack.NATTable:
+ case natTable:
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
@@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx], ipv6) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx)
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
-
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case *stack.AcceptTarget, *stack.DropTarget:
+ case *acceptTarget, *dropTarget:
return true
default:
return false
@@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(*stack.AcceptTarget)
+ _, ok := rule.Target.(*acceptTarget)
return ok
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 0e14447fe..f2653d523 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -26,6 +26,15 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port and/or IP for packets.
+const RedirectTargetName = "REDIRECT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -52,25 +61,96 @@ func init() {
})
}
+// The stack package provides some basic, useful targets for us. The following
+// types wrap them for compatibility with the extension system.
+
+type acceptTarget struct {
+ stack.AcceptTarget
+}
+
+func (at *acceptTarget) id() targetID {
+ return targetID{
+ networkProtocol: at.NetworkProtocol,
+ }
+}
+
+type dropTarget struct {
+ stack.DropTarget
+}
+
+func (dt *dropTarget) id() targetID {
+ return targetID{
+ networkProtocol: dt.NetworkProtocol,
+ }
+}
+
+type errorTarget struct {
+ stack.ErrorTarget
+}
+
+func (et *errorTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: et.NetworkProtocol,
+ }
+}
+
+type userChainTarget struct {
+ stack.UserChainTarget
+}
+
+func (uc *userChainTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: uc.NetworkProtocol,
+ }
+}
+
+type returnTarget struct {
+ stack.ReturnTarget
+}
+
+func (rt *returnTarget) id() targetID {
+ return targetID{
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
+type redirectTarget struct {
+ stack.RedirectTarget
+
+ // addr must be (un)marshalled when reading and writing the target to
+ // userspace, but does not affect behavior.
+ addr tcpip.Address
+}
+
+func (rt *redirectTarget) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (sm *standardTargetMaker) id() stack.TargetID {
+func (sm *standardTargetMaker) id() targetID {
// Standard targets have the empty string as a name and no revisions.
- return stack.TargetID{
- NetworkProtocol: sm.NetworkProtocol,
+ return targetID{
+ networkProtocol: sm.NetworkProtocol,
}
}
-func (*standardTargetMaker) marshal(target stack.Target) []byte {
+
+func (*standardTargetMaker) marshal(target target) []byte {
// Translate verdicts the same way as the iptables tool.
var verdict int32
switch tg := target.(type) {
- case *stack.AcceptTarget:
+ case *acceptTarget:
verdict = -linux.NF_ACCEPT - 1
- case *stack.DropTarget:
+ case *dropTarget:
verdict = -linux.NF_DROP - 1
- case *stack.ReturnTarget:
+ case *returnTarget:
verdict = linux.NF_RETURN
case *JumpTarget:
verdict = int32(tg.Offset)
@@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTStandardTarget {
nflog("buf has wrong size for standard target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -114,20 +194,20 @@ type errorTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (em *errorTargetMaker) id() stack.TargetID {
+func (em *errorTargetMaker) id() targetID {
// Error targets have no revision.
- return stack.TargetID{
- Name: stack.ErrorTargetName,
- NetworkProtocol: em.NetworkProtocol,
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: em.NetworkProtocol,
}
}
-func (*errorTargetMaker) marshal(target stack.Target) []byte {
+func (*errorTargetMaker) marshal(target target) []byte {
var errorName string
switch tg := target.(type) {
- case *stack.ErrorTarget:
- errorName = stack.ErrorTargetName
- case *stack.UserChainTarget:
+ case *errorTarget:
+ errorName = ErrorTargetName
+ case *userChainTarget:
errorName = tg.Name
default:
panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
@@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte {
},
}
copy(xt.Name[:], errorName)
- copy(xt.Target.Name[:], stack.ErrorTargetName)
+ copy(xt.Target.Name[:], ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTErrorTarget {
nflog("buf has insufficient size for error target %d", len(buf))
return nil, syserr.ErrInvalidArgument
}
- var errorTarget linux.XTErrorTarget
+ var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &errTgt)
// Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named stack.ErrorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
+ // * An actual error case. These rules have an error named
+ // ErrorTargetName. The last entry of the table is usually an error
+ // case to catch any packets that somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case stack.ErrorTargetName:
- return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ switch name := errTgt.Name.String(); name {
+ case ErrorTargetName:
+ return &errorTarget{stack.ErrorTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}, nil
default:
// User defined chain.
- return &stack.UserChainTarget{
+ return &userChainTarget{stack.UserChainTarget{
Name: name,
NetworkProtocol: filter.NetworkProtocol(),
- }, nil
+ }}, nil
}
}
@@ -178,22 +259,22 @@ type redirectTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *redirectTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *redirectTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*redirectTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*redirectTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
// This is a redirect target named redirect
xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(xt.Target.Name[:], stack.RedirectTargetName)
+ copy(xt.Target.Name[:], RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
@@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) < linux.SizeOfXTRedirectTarget {
nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- var redirectTarget linux.XTRedirectTarget
+ var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &rt)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
- target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+ target := redirectTarget{RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
// RangeSize should be 1.
- nfRange := redirectTarget.NfRange
+ nfRange := rt.NfRange
if nfRange.RangeSize != 1 {
nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
return nil, syserr.ErrInvalidArgument
@@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
target.Port = ntohs(nfRange.RangeIPV4.MinPort)
return &target, nil
@@ -264,15 +347,15 @@ type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *nfNATTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *nfNATTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*nfNATTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
TargetSize: nfNATMarhsalledSize,
@@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
}
- copy(nt.Target.Name[:], stack.RedirectTargetName)
- copy(nt.Range.MinAddr[:], rt.Addr)
- copy(nt.Range.MaxAddr[:], rt.Addr)
+ copy(nt.Target.Name[:], RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.addr)
+ copy(nt.Range.MaxAddr[:], rt.addr)
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
@@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if size := nfNATMarhsalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
@@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
return nil, syserr.ErrInvalidArgument
}
- target := stack.RedirectTarget{
- NetworkProtocol: filter.NetworkProtocol(),
- Addr: tcpip.Address(natRange.MinAddr[:]),
- Port: ntohs(natRange.MinProto),
+ target := redirectTarget{
+ RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Port: ntohs(natRange.MinProto),
+ },
+ addr: tcpip.Address(natRange.MinAddr[:]),
}
return &target, nil
@@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
-func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
+ return &acceptTarget{stack.AcceptTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_DROP - 1:
- return &stack.DropTarget{NetworkProtocol: netProto}, nil
+ return &dropTarget{stack.DropTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_QUEUE - 1:
nflog("unsupported iptables verdict QUEUE")
return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
+ return &returnTarget{stack.ReturnTarget{
+ NetworkProtocol: netProto,
+ }}, nil
default:
nflog("unknown iptables verdict %d", val)
return nil, syserr.ErrInvalidArgument
@@ -382,9 +473,9 @@ type JumpTarget struct {
}
// ID implements Target.ID.
-func (jt *JumpTarget) ID() stack.TargetID {
- return stack.TargetID{
- NetworkProtocol: jt.NetworkProtocol,
+func (jt *JumpTarget) id() targetID {
+ return targetID{
+ networkProtocol: jt.NetworkProtocol,
}
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 844acfede..352c51390 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.TCPProtocolNumber {
- return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber)
}
return &TCPMatcher{
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 63201201c..c88d8268d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
}
return &UDPMatcher{
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
index e8930f031..f061c5d62 100644
--- a/pkg/sentry/socket/netlink/provider_vfs2.go
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol
vfsfd := &s.vfsfd
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index c84d8bd7c..f4d034c13 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -36,9 +36,9 @@ type commandKind int
const (
kindNew commandKind = 0x0
- kindDel = 0x1
- kindGet = 0x2
- kindSet = 0x3
+ kindDel commandKind = 0x1
+ kindGet commandKind = 0x2
+ kindSet commandKind = 0x3
)
func typeKind(typ uint16) commandKind {
@@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
}
attrs = rest
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We add the local interface address here
+ // and ignore the IFA_ADDRESS.
switch ahdr.Type {
case linux.IFA_LOCAL:
err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
@@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
} else if err != nil {
return syserr.ErrInvalidArgument
}
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
}
}
return nil
}
+// delAddr handles RTM_DELADDR requests.
+func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We use the local interface address to
+ // remove the address and ignore the IFA_ADDRESS.
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err != nil {
+ return syserr.ErrBadLocalAddress
+ }
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
hdr := msg.Header()
@@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
return p.newAddr(ctx, msg, ms)
+ case linux.RTM_DELADDR:
+ return p.delAddr(ctx, msg, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index c83b23242..461d524e5 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -37,6 +37,8 @@ import (
// to/from the kernel.
//
// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 211f07947..86c634715 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1244,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
+ case linux.SO_ACCEPTCONN:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 4c6791fff..b0d9e4d9e 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -35,6 +35,8 @@ import (
// SocketVFS2 encapsulates all the state needed to represent a network stack
// endpoint in the kernel context.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
}
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &SocketVFS2{
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 1028d2a6e..fa9ac9059 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
-// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+// convertAddr converts an InterfaceAddr to a ProtocolAddress.
+func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) {
var (
- protocol tcpip.NetworkProtocolNumber
- address tcpip.Address
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ protocolAddress tcpip.ProtocolAddress
)
switch addr.Family {
case linux.AF_INET:
- if len(addr.Addr) < header.IPv4AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv4AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv4.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
-
+ address = tcpip.Address(addr.Addr)
case linux.AF_INET6:
- if len(addr.Addr) < header.IPv6AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv6AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv6AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv6.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
-
+ address = tcpip.Address(addr.Addr)
default:
- return syserror.ENOTSUP
+ return protocolAddress, syserror.ENOTSUP
}
- protocolAddress := tcpip.ProtocolAddress{
+ protocolAddress = tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: int(addr.PrefixLen),
},
}
+ return protocolAddress, nil
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
// Attach address to interface.
- if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network if it doesn't exist already.
+ localRoute := tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: nicID,
+ }
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ if rt.Equal(localRoute) {
+ return nil
+ }
+ }
+
+ // Local route does not exist yet. Add it.
+ s.Stack.AddRoute(localRoute)
+
+ return nil
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
+
+ // Remove addresses matching the address and prefix.
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
- // Add route for local network.
- s.Stack.AddRoute(tcpip.Route{
+ // Remove the corresponding local network route if it exists.
+ localRoute := tcpip.Route{
Destination: protocolAddress.AddressWithPrefix.Subnet(),
Gateway: "", // No gateway for local network.
- NIC: tcpip.NICID(idx),
+ NIC: nicID,
+ }
+ s.Stack.RemoveRoutes(func(rt tcpip.Route) bool {
+ return rt.Equal(localRoute)
})
+
return nil
}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cc7408698..cce0acc33 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "socket_refs.go",
package = "unix",
prefix = "socketOperations",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketOperations",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "socket_vfs2_refs.go",
package = "unix",
prefix = "socketVFS2",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketVFS2",
},
@@ -43,6 +43,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 26c3a51b9..3ebbd28b0 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "queue_refs.go",
package = "transport",
prefix = "queue",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "queue",
},
@@ -44,6 +44,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d6fc03520..b648273a4 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -32,6 +32,8 @@ import (
const initialLimit = 16 * 1024
// A RightsControlMessage is a control message containing FDs.
+//
+// +stateify savable
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
Clone() RightsControlMessage
@@ -336,7 +338,7 @@ type Receiver interface {
RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
- // called before droping all references to a Receiver.
+ // called before dropping all references to a Receiver.
Release(ctx context.Context)
}
@@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
c := q.control.Clone()
// Don't consume data since we are peeking.
- copied, data, _ = vecCopy(data, q.buffer)
+ copied, _, _ = vecCopy(data, q.buffer)
return copied, copied, c, false, q.addr, notify, nil
}
@@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
return copied, copied, c, cmTruncated, q.addr, notify, nil
}
+// Release implements Receiver.Release.
+func (q *streamQueueReceiver) Release(ctx context.Context) {
+ q.queueReceiver.Release(ctx)
+ q.control.Release(ctx)
+}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
// Passcred implements Endpoint.Passcred.
@@ -619,7 +627,7 @@ type ConnectedEndpoint interface {
SendMaxQueueSize() int64
// Release releases any resources owned by the ConnectedEndpoint. It should
- // be called before droping all references to a ConnectedEndpoint.
+ // be called before dropping all references to a ConnectedEndpoint.
Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
@@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.PasscredOption:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index a4a76d0a3..adad485a9 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
},
}
s.EnableLeakCheck()
-
return fs.NewFile(ctx, d, flags, &s)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 678355fb9..7a78444dc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// returns a corresponding file description.
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
@@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.EnableLeakCheck()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index 0ea4aab8b..563d60578 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -12,10 +12,12 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/time",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
"//pkg/syserror",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 245d2c5cf..167754537 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -19,10 +19,12 @@ import (
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/pkg/syserror"
@@ -57,7 +59,7 @@ type SaveOpts struct {
}
// Save saves the system state.
-func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
+func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error {
log.Infof("Sandbox save started, pausing all tasks.")
k.Pause()
k.ReceiveTaskStates()
@@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
err = ErrStateFile{err}
} else {
// Save the kernel.
- err = k.SaveTo(wc)
+ err = k.SaveTo(ctx, wc)
// ENOSPC is a state file error. This error can only come from
// writing the state file, and not from fs.FileOperations.Fsync
@@ -108,7 +110,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
+func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(r, n, clocks)
+ return k.LoadFrom(ctx, r, n, clocks, vfsOpts)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 9c9def7cd..bb1f715e2 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 849a47476..f7135ea46 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return 0, syserror.EINVAL
}
- r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
+ r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize)
r.SetFlags(linuxToFlags(flags).Settable())
defer r.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 47dadb800..e383a0a87 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -129,13 +129,27 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getPID(t, id, num)
return uintptr(v), nil, err
+ case linux.IPC_STAT:
+ arg := args[3].Pointer()
+ ds, err := ipcStat(t, id)
+ if err == nil {
+ _, err = ds.CopyOut(t, arg)
+ }
+
+ return 0, nil, err
+
+ case linux.GETZCNT:
+ v, err := getZCnt(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.GETNCNT:
+ v, err := getNCnt(t, id, num)
+ return uintptr(v), nil, err
+
case linux.IPC_INFO,
linux.SEM_INFO,
- linux.IPC_STAT,
linux.SEM_STAT,
- linux.SEM_STAT_ANY,
- linux.GETNCNT,
- linux.GETZCNT:
+ linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -171,6 +185,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP
return set.Change(t, creds, owner, perms)
}
+func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.GetStat(creds)
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
@@ -240,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
}
return int32(tg.ID()), nil
}
+
+func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountZeroWaiters(num, creds)
+}
+
+func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountNegativeWaiters(num, creds)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 46616c961..1c4cdb0dd 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inCh chan struct{}
outCh chan struct{}
)
+
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
opts.Length -= n
@@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inW, _ := waiter.NewChannelEntry(inCh)
inFile.EventRegister(&inW, EventMaskRead)
defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Need to refresh readiness.
+ continue
}
if err = t.Block(inCh); err != nil {
break
}
}
- if outFile.Readiness(EventMaskWrite) == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, EventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ // Don't bother checking readiness of the outFile, because it's not a
+ // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds
+ // can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK.
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ // We might be ready to write now. Try again before
+ // blocking.
+ continue
+ }
+ if err = t.Block(outCh); err != nil {
+ break
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 035e2a6b0..9ce4f280a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
// waitForOut waits for dw.outfile to be read.
func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
- if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if dw.outCh == nil {
- dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
- dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
- // We might be ready now. Try again before blocking.
- return nil
- }
- if err := t.Block(dw.outCh); err != nil {
- return err
- }
- }
- return nil
+ // Don't bother checking readiness of the outFile, because it's not a
+ // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds
+ // can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK. See b/172075629, b/170743336.
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready to write now. Try again before blocking.
+ return nil
+ }
+ return t.Block(dw.outCh)
}
// destroy cleans up resources help by dw. No more calls to wait* can occur
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index c855608db..440c9307c 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -32,7 +32,7 @@ go_template_instance(
out = "file_description_refs.go",
package = "vfs",
prefix = "FileDescription",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FileDescription",
},
@@ -43,7 +43,7 @@ go_template_instance(
out = "mount_namespace_refs.go",
package = "vfs",
prefix = "MountNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "MountNamespace",
},
@@ -54,7 +54,7 @@ go_template_instance(
out = "filesystem_refs.go",
package = "vfs",
prefix = "Filesystem",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Filesystem",
},
@@ -87,6 +87,7 @@ go_library(
"pathname.go",
"permissions.go",
"resolving_path.go",
+ "save_restore.go",
"vfs.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -99,6 +100,7 @@ go_library(
"//pkg/gohacks",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 8f36c3e3b..a98aac52b 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -74,7 +74,7 @@ type epollInterestKey struct {
// +stateify savable
type epollInterest struct {
// epoll is the owning EpollInstance. epoll is immutable.
- epoll *EpollInstance
+ epoll *EpollInstance `state:"wait"`
// key is the file to which this epollInterest applies. key is immutable.
key epollInterestKey
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 183957ad8..546e445aa 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 2d27d9d35..ba6e6ed49 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
return vfs.PrependPathAtVFSRootError{}
}
- if &d.vfsd == mnt.Root() {
+ if mnt != nil && &d.vfsd == mnt.Root() {
return nil
}
if d.parent == nil {
@@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
d = d.parent
}
}
+
+// DebugPathname returns a pathname to d relative to its filesystem root.
+// DebugPathname does not correspond to any Linux function; it's used to
+// generate dentry pathnames for debugging.
+func DebugPathname(d *Dentry) string {
+ var b fspath.Builder
+ _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ return b.String()
+}
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 3f0b8f45b..107171b61 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -65,7 +65,7 @@ type Inotify struct {
// queue is used to notify interested parties when the inotify instance
// becomes readable or writable.
- queue waiter.Queue `state:"nosave"`
+ queue waiter.Queue
// evMu *only* protects the events list. We need a separate lock while
// queuing events: using mu may violate lock ordering, since at that point
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index 55783d4eb..1ff202f2a 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package lock provides POSIX and BSD style file locking for VFS2 file
-// implementations.
-//
-// The actual implementations can be found in the lock package under
-// sentry/fs/lock.
package vfs
import (
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 78f115bfa..3ea981ad4 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -106,6 +107,7 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount
if opts.ReadOnly {
mnt.setReadOnlyLocked(true)
}
+ refsvfs2.Register(mnt)
return mnt
}
@@ -470,11 +472,12 @@ func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
// tryIncMountedRef does not require that a reference is held on mnt.
func (mnt *Mount) tryIncMountedRef() bool {
for {
- refs := atomic.LoadInt64(&mnt.refs)
- if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
+ r := atomic.LoadInt64(&mnt.refs)
+ if r <= 0 { // r < 0 => MSB set => eagerly unmounted
return false
}
- if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(mnt, r+1)
return true
}
}
@@ -484,29 +487,53 @@ func (mnt *Mount) tryIncMountedRef() bool {
func (mnt *Mount) IncRef() {
// In general, negative values for mnt.refs are valid because the MSB is
// the eager-unmount bit.
- atomic.AddInt64(&mnt.refs, 1)
+ r := atomic.AddInt64(&mnt.refs, 1)
+ refsvfs2.LogIncRef(mnt, r)
}
// DecRef decrements mnt's reference count.
func (mnt *Mount) DecRef(ctx context.Context) {
- refs := atomic.AddInt64(&mnt.refs, -1)
- if refs&^math.MinInt64 == 0 { // mask out MSB
- var vd VirtualDentry
- if mnt.parent() != nil {
- mnt.vfs.mountMu.Lock()
- mnt.vfs.mounts.seq.BeginWrite()
- vd = mnt.vfs.disconnectLocked(mnt)
- mnt.vfs.mounts.seq.EndWrite()
- mnt.vfs.mountMu.Unlock()
- }
- if mnt.root != nil {
- mnt.root.DecRef(ctx)
- }
- mnt.fs.DecRef(ctx)
- if vd.Ok() {
- vd.DecRef(ctx)
- }
+ r := atomic.AddInt64(&mnt.refs, -1)
+ if r&^math.MinInt64 == 0 { // mask out MSB
+ refsvfs2.Unregister(mnt)
+ mnt.destroy(ctx)
+ }
+}
+
+func (mnt *Mount) destroy(ctx context.Context) {
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
}
+ mnt.fs.DecRef(ctx)
+ if vd.Ok() {
+ vd.DecRef(ctx)
+ }
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (mnt *Mount) RefType() string {
+ return "vfs.Mount"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (mnt *Mount) LeakMessage() string {
+ return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (mnt *Mount) LogRefs() bool {
+ return false
}
// DecRef decrements mntns' reference count.
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index cb8c56bd3..cb882a983 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) {
parent := &Mount{}
point := &Dentry{}
if m := mt.Lookup(parent, point); m != nil {
- t.Errorf("empty mountTable lookup: got %p, wanted nil", m)
+ t.Errorf("Empty mountTable lookup: got %p, wanted nil", m)
}
}
@@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
mi, ok := ms.Load(k)
if !ok {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
m := mi.(*Mount)
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m, _ := ms.Load(k)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index b7d122d22..cb48c37a1 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -98,7 +98,6 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- // FIXME(gvisor.dev/issue/1663): Slots need to be saved.
slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
@@ -212,6 +211,26 @@ loop:
}
}
+// Range calls f on each Mount in mt. If f returns false, Range stops iteration
+// and returns immediately.
+func (mt *mountTable) Range(f func(*Mount) bool) {
+ tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
+ slotPtr := mt.slots
+ last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes))
+ for {
+ slot := (*mountSlot)(slotPtr)
+ if slot.value != nil {
+ if !f((*Mount)(slot.value)) {
+ return
+ }
+ }
+ if slotPtr == last {
+ return
+ }
+ slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes)
+ }
+}
+
// Insert inserts the given mount into mt.
//
// Preconditions: mt must not already contain a Mount with the same mount point
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
new file mode 100644
index 000000000..7723ed643
--- /dev/null
+++ b/pkg/sentry/vfs/save_restore.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// FilesystemImplSaveRestoreExtension is an optional extension to
+// FilesystemImpl.
+type FilesystemImplSaveRestoreExtension interface {
+ // PrepareSave prepares this filesystem for serialization.
+ PrepareSave(ctx context.Context) error
+
+ // CompleteRestore completes restoration from checkpoint for this
+ // filesystem after deserialization.
+ CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error
+}
+
+// PrepareSave prepares all filesystems for serialization.
+func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.PrepareSave(ctx); err != nil {
+ ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to prepare for serialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestore completes restoration from checkpoint for all filesystems
+// after deserialization.
+func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.CompleteRestore(ctx, *opts); err != nil {
+ ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestoreOptions contains options to
+// VirtualFilesystem.CompleteRestore() and
+// FilesystemImplSaveRestoreExtension.CompleteRestore().
+type CompleteRestoreOptions struct {
+ // If ValidateFileSizes is true, filesystem implementations backed by
+ // remote filesystems should verify that file sizes have not changed
+ // between checkpoint and restore.
+ ValidateFileSizes bool
+
+ // If ValidateFileModificationTimestamps is true, filesystem
+ // implementations backed by remote filesystems should validate that file
+ // mtimes have not changed between checkpoint and restore.
+ ValidateFileModificationTimestamps bool
+}
+
+// saveMounts is called by stateify.
+func (vfs *VirtualFilesystem) saveMounts() []*Mount {
+ if atomic.LoadPointer(&vfs.mounts.slots) == nil {
+ // vfs.Init() was never called.
+ return nil
+ }
+ var mounts []*Mount
+ vfs.mounts.Range(func(mount *Mount) bool {
+ mounts = append(mounts, mount)
+ return true
+ })
+ return mounts
+}
+
+// loadMounts is called by stateify.
+func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
+ if mounts == nil {
+ return
+ }
+ vfs.mounts.Init()
+ for _, mount := range mounts {
+ vfs.mounts.Insert(mount)
+ }
+}
+
+func (mnt *Mount) afterLoad() {
+ if atomic.LoadInt64(&mnt.refs) != 0 {
+ refsvfs2.Register(mnt)
+ }
+}
+
+// afterLoad is called by stateify.
+func (epi *epollInterest) afterLoad() {
+ // Mark all epollInterests as ready after restore so that the next call to
+ // EpollInstance.ReadEvents() rechecks their readiness.
+ epi.Callback(nil)
+}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 38d2701d2..48d6252f7 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -71,7 +71,7 @@ type VirtualFilesystem struct {
// points.
//
// mounts is analogous to Linux's mount_hashtable.
- mounts mountTable
+ mounts mountTable `state:".([]*Mount)"`
// mountpoints maps mount points to mounts at those points in all
// namespaces. mountpoints is protected by mountMu.
@@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
// SyncAllFilesystems has the semantics of Linux's sync(2).
func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ var retErr error
+ for fs := range vfs.getFilesystems() {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef(ctx)
+ }
+ return retErr
+}
+
+func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} {
fss := make(map[*Filesystem]struct{})
vfs.filesystemsMu.Lock()
+ defer vfs.filesystemsMu.Unlock()
for fs := range vfs.filesystems {
if !fs.TryIncRef() {
continue
}
fss[fs] = struct{}{}
}
- vfs.filesystemsMu.Unlock()
- var retErr error
- for fs := range fss {
- if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
- retErr = err
- }
- fs.DecRef(ctx)
- }
- return retErr
+ return fss
}
// MkdirAllAt recursively creates non-existent directories on the given path
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
index f08599ebd..cb0001852 100644
--- a/pkg/shim/runsc/BUILD
+++ b/pkg/shim/runsc/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "@com_github_containerd_containerd//log:go_default_library",
"@com_github_containerd_go_runc//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
index c5cf68efa..e7c9640ba 100644
--- a/pkg/shim/runsc/runsc.go
+++ b/pkg/shim/runsc/runsc.go
@@ -28,10 +28,12 @@ import (
"syscall"
"time"
+ "github.com/containerd/containerd/log"
runc "github.com/containerd/go-runc"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
+// Monitor is the default process monitor to be used by runsc.
var Monitor runc.ProcessMonitor = runc.Monitor
// DefaultCommand is the default command for Runsc.
@@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro
return &c, nil
}
+// CreateOpts is a set of options to Runsc.Create().
type CreateOpts struct {
runc.IO
ConsoleSocket runc.ConsoleSocket
@@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) {
return res.ExitStatus, nil
}
+// ExecOpts is a set of options to runsc.Exec().
type ExecOpts struct {
runc.IO
PidFile string
@@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts
return Monitor.Wait(cmd, ec)
}
+// DeleteOpts is a set of options to runsc.Delete().
type DeleteOpts struct {
Force bool
}
@@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
if err := json.NewDecoder(rd).Decode(&e); err != nil {
return nil, err
}
+ log.L.Debugf("Stats returned: %+v", e.Stats)
+ if e.Type != "stats" {
+ return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type)
+ }
+ if e.Stats == nil {
+ return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`)
+ }
return e.Stats, nil
}
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 089b3bbef..92c51879b 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "pending_list",
- out = "pending_list.go",
- package = "state",
- prefix = "pending",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectEncodeState",
- "ElementMapper": "pendingMapper",
- "Linker": "*pendingEntry",
- },
-)
-
-go_template_instance(
name = "deferred_list",
out = "deferred_list.go",
package = "state",
@@ -83,7 +70,6 @@ go_library(
"deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "pending_list.go",
"state.go",
"state_norace.go",
"state_race.go",
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 89467ca8e..e519ddeca 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -21,6 +21,7 @@ import (
"math"
"reflect"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/state/wire"
)
@@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
-// directly assignable. See unsafePointerTo.
+// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
@@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
- if v.IsValid() {
- obj.Set(unsafePointerTo(v))
- }
+ obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
@@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
- obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
@@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
ds.pending.PushBack(rootOds)
// Read the number of objects.
- lastID, object, err := ReadHeader(ds.r)
+ numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
@@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) {
var (
encoded wire.Object
ods *objectDecodeState
- id = objectID(1)
+ id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
- // Note that the structure of this decoding loop should match
- // the raw decoding loop in printer.go.
- for id <= objectID(lastID) {
- // Unmarshal the object.
+ // Note that the structure of this decoding loop should match the raw
+ // decoding loop in state/pretty/pretty.printer.printStream().
+ for i := uint64(0); i < numObjects; {
+ // Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
-
- // Is this a type object? Handle inline.
- if wt, ok := encoded.(*wire.Type); ok {
- ds.types.Register(wt)
+ switch we := encoded.(type) {
+ case *wire.Type:
+ ds.types.Register(we)
tid++
encoded = nil
continue
+ case wire.Uint:
+ id = objectID(we)
+ i++
+ // Unmarshal and resolve the actual object.
+ encoded = wire.Load(ds.r)
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+ // For error handling.
+ ods = nil
+ encoded = nil
+ default:
+ Failf("wanted type or object ID, got %#v", encoded)
}
-
- // Actually resolve the object.
- ods = ds.lookup(id)
- if ods != nil {
- // Decode the object.
- ds.decodeObject(ods, ods.obj, encoded)
- } else {
- // If an object hasn't had interest registered
- // previously or isn't yet valid, we deferred
- // decoding until interest is registered.
- ds.deferred[id] = encoded
- }
-
- // For error handling.
- ods = nil
- encoded = nil
- id++
}
}); err != nil {
// Include as much information as we can, taking into account
@@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) {
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
- Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
+ numDeferred := 0
for id, encoded := range ds.deferred {
- // Shoud never happen, the graph was bogus.
- Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ numDeferred++
+ if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
+ typ := ds.types.LookupType(typeID(s.TypeID))
+ log.Warningf("unused deferred object: ID %d, type %v", id, typ)
+ } else {
+ log.Warningf("unused deferred object: ID %d, %#v", id, encoded)
+ }
+ }
+ if numDeferred != 0 {
+ Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
index d048f61a1..f1208e2a2 100644
--- a/pkg/state/decode_unsafe.go
+++ b/pkg/state/decode_unsafe.go
@@ -15,13 +15,62 @@
package state
import (
+ "fmt"
"reflect"
+ "runtime"
"unsafe"
)
-// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
-// values representing unexported fields. This bypasses visibility, but not
-// type safety.
-func unsafePointerTo(obj reflect.Value) reflect.Value {
+// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned
+// reflect.Value is usable in assignments even if obj was obtained by the use
+// of unexported struct fields.
+//
+// Preconditions: obj.CanAddr().
+func reflectValueRWAddr(obj reflect.Value) reflect.Value {
return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
}
+
+// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the
+// returned reflect.Value is usable in assignments even if obj was obtained by
+// the use of unexported struct fields.
+//
+// Preconditions:
+// * arr.Kind() == reflect.Array.
+// * i, j, k >= 0.
+// * i <= j <= k <= arr.Len().
+func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value {
+ if arr.Kind() != reflect.Array {
+ panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array))
+ }
+ if i < 0 || j < 0 || k < 0 {
+ panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k))
+ }
+ if i > j {
+ panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j))
+ }
+ if j > k {
+ panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k))
+ }
+ if k > arr.Len() {
+ panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len()))
+ }
+
+ sliceTyp := reflect.SliceOf(arr.Type().Elem())
+ if i == arr.Len() {
+ // By precondition, i == j == k == arr.Len().
+ return reflect.MakeSlice(sliceTyp, 0, 0)
+ }
+ slh := reflect.SliceHeader{
+ // reflect.Value.CanAddr() == false for arrays, so we need to get the
+ // address from the first element of the array.
+ Data: arr.Index(i).UnsafeAddr(),
+ Len: j - i,
+ Cap: k - i,
+ }
+ slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem()
+ // Before slobj is constructed, arr holds the only pointer-typed pointer to
+ // the array since reflect.SliceHeader.Data is a uintptr, so arr must be
+ // kept alive.
+ runtime.KeepAlive(arr)
+ return slobj
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 92fcad4e9..560e7c2a3 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -17,13 +17,14 @@ package state
import (
"context"
"reflect"
+ "sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectEncodeState the type and identity of an object occupying a memory
// address range. This is the value type for addrSet, and the intrusive entry
-// for the pending and deferred lists.
+// for the deferred list.
type objectEncodeState struct {
// id is the assigned ID for this object.
id objectID
@@ -47,7 +48,6 @@ type objectEncodeState struct {
// references may be updated directly and automatically.
refs []*wire.Ref
- pendingEntry
deferredEntry
}
@@ -93,9 +93,15 @@ type encodeState struct {
// serialized.
pendingTypes []wire.Type
- // pending is the list of objects to be serialized. Serialization does
+ // pending maps object IDs to objects to be serialized. Serialization does
// not actually occur until the full object graph is computed.
- pending pendingList
+ pending map[objectID]*objectEncodeState
+
+ // encodedStructs maps reflect.Values representing structs to previous
+ // encodings of those structs. This is necessary to avoid duplicate calls
+ // to SaverLoader.StateSave() that may result in multiple calls to
+ // Sink.SaveValue() for a given field, resulting in object duplication.
+ encodedStructs map[reflect.Value]*wire.Struct
// stats tracks time data.
stats Stats
@@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
// depending on this value knows there's nothing there.
return
}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg, gap := es.values.Find(addr)
+ if seg.Ok() {
// Ensure the map types match.
existing := seg.Value()
if existing.obj.Type() != obj.Type() {
@@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
}
// Record the map.
+ r := addrRange{addr, addr + 1}
oes := &objectEncodeState{
id: es.nextID(),
obj: obj,
how: encodeMapAsValue,
}
- es.values.Add(addrRange{addr, addr + 1}, oes)
- es.pending.PushBack(oes)
+ // Use Insert instead of InsertWithoutMergingUnchecked when race
+ // detection is enabled to get additional sanity-checking from Merge.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
// See above: no ref recording.
@@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
obj: obj,
}
es.zeroValues[typ] = oes
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
}
@@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
size = 1 // See above.
}
- // Calculate the container.
end := addr + size
r := addrRange{addr, end}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg := es.values.LowerBoundSegment(addr)
+ var (
+ oes *objectEncodeState
+ gap addrGapIterator
+ )
+
+ // Does at least one previously-registered object overlap this one?
+ if seg.Ok() && seg.Start() < end {
existing := seg.Value()
- switch {
- case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
- // The object is a perfect match. Happy path. Avoid the
- // traversal and just return directly. We don't need to
- // encode the type information or any dots here.
+
+ if seg.Range() == r && typ == existing.obj.Type() {
+ // This exact object is already registered. Avoid the traversal and
+ // just return directly. We don't need to encode the type
+ // information or any dots here.
ref.Root = wire.Uint(existing.id)
existing.refs = append(existing.refs, ref)
return
+ }
- case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
- // The previously registered object is larger than
- // this, no need to update. But we expect some
- // traversal below.
+ if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) {
+ // This object is contained within a previously-registered object.
+ // Perform traversal from the container to the new object.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr)
+ ref.Type = es.findType(existing.obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
- case seg.Start() == addr && seg.End() == end:
- if !isSameSizeParent(obj, existing.obj.Type()) {
- break // Needs traversal.
+ // This object contains one or more previously-registered objects.
+ // Remove them and update existing references to use the new one.
+ oes := &objectEncodeState{
+ // Reuse the root ID of the first contained element.
+ id: existing.id,
+ obj: obj,
+ }
+ type elementEncodeState struct {
+ addr uintptr
+ typ reflect.Type
+ refs []*wire.Ref
+ }
+ var (
+ elems []elementEncodeState
+ gap addrGapIterator
+ )
+ for {
+ // Each contained object should be completely contained within
+ // this one.
+ if raceEnabled && !r.IsSupersetOf(seg.Range()) {
+ Failf("containing object %#v does not contain existing object %#v", obj, existing.obj)
}
- fallthrough // Needs update.
-
- case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
- // Update the object and redo the encoding.
- old := existing.obj
- existing.obj = obj
+ elems = append(elems, elementEncodeState{
+ addr: seg.Start(),
+ typ: existing.obj.Type(),
+ refs: existing.refs,
+ })
+ delete(es.pending, existing.id)
es.deferred.Remove(existing)
- es.deferred.PushBack(existing)
-
- // The previously registered object is superseded by
- // this new object. We are guaranteed to not have any
- // mergeable neighbours in this segment set.
- if !raceEnabled {
- seg.SetRangeUnchecked(r)
- } else {
- // Add extra paranoid. This will be statically
- // removed at compile time unless a race build.
- es.values.Remove(seg)
- es.values.Add(r, existing)
- seg = es.values.LowerBoundSegment(addr)
+ gap = es.values.Remove(seg)
+ seg = gap.NextSegment()
+ if !seg.Ok() || seg.Start() >= end {
+ break
}
-
- // Compute the traversal required & update references.
- dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
- wt := es.findType(obj.Type())
- for _, ref := range existing.refs {
+ existing = seg.Value()
+ }
+ wt := es.findType(typ)
+ for _, elem := range elems {
+ dots := traverse(typ, elem.typ, addr, elem.addr)
+ for _, ref := range elem.refs {
+ ref.Root = wire.Uint(oes.id)
ref.Dots = append(ref.Dots, dots...)
ref.Type = wt
}
- default:
- // There is a non-sensical overlap.
- Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ oes.refs = append(oes.refs, elem.refs...)
}
-
- // Compute the new reference, record and return it.
- ref.Root = wire.Uint(existing.id)
- ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
- ref.Type = es.findType(obj.Type())
- existing.refs = append(existing.refs, ref)
+ // Finally register the new containing object.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
return
}
- // The only remaining case is a pointer value that doesn't overlap with
- // any registered addresses. Create a new entry for it, and start
- // tracking the first reference we just created.
- oes := &objectEncodeState{
+ // No existing object overlaps this one. Register a new object.
+ oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
+ if seg.Ok() {
+ gap = seg.PrevGap()
+ } else {
+ gap = es.values.LastGap()
+ }
if !raceEnabled {
- es.values.AddWithoutMerging(r, oes)
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
- // Merges should never happen. This is just enabled extra
- // sanity checks because the Merge function below will panic.
- es.values.Add(r, oes)
+ es.values.Insert(gap, r, oes)
}
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
@@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) {
// encodeStruct encodes a composite object.
func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ if s, ok := es.encodedStructs[obj]; ok {
+ *dest = s
+ return
+ }
+ s := &wire.Struct{}
+ *dest = s
+ es.encodedStructs[obj] = s
+
// Ensure that the obj is addressable. There are two cases when it is
// not. First, is when this is dispatched via SaveValue. Second, when
// this is a map key as a struct. Either way, we need to make a copy to
@@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
obj = localObj.Elem()
}
- // Prepare the value.
- s := &wire.Struct{}
- *dest = s
-
// Look the type up in the database.
te, ok := es.types.Lookup(obj.Type())
if te == nil {
@@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) {
Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
}
- // Check that items are pending.
- if es.pending.Front() == nil {
+ // Check that we have objects to serialize.
+ if len(es.pending) == 0 {
Failf("pending is empty?")
}
- // Write the header with the number of objects. Note that there is no
- // way that es.lastID could conflict with objectID, which would
- // indicate that an impossibly large encoding.
- if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ // Write the header with the number of objects.
+ if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
// Serialize all pending types and pending objects. Note that we don't
// bother removing from this list as we walk it because that just
// wastes time. It will not change after this point.
- var id objectID
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
}
- for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
- id++ // First object is 1.
- if oes.id != id {
- Failf("expected id %d, got %d", id, oes.id)
- }
-
- // Marshall the object.
+ // Emit objects in ID order.
+ ids := make([]objectID, 0, len(es.pending))
+ for id := range es.pending {
+ ids = append(ids, id)
+ }
+ sort.Slice(ids, func(i, j int) bool {
+ return ids[i] < ids[j]
+ })
+ for _, id := range ids {
+ // Encode the id.
+ wire.Save(es.w, wire.Uint(id))
+ // Marshal the object.
+ oes := es.pending[id]
wire.Save(es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
Failf("error serializing object %#v: %w", oes.encoded, err)
}
-
- // Check what we wrote.
- if id != es.lastID {
- Failf("expected %d objects, wrote %d", es.lastID, id)
- }
}
// objectFlag indicates that the length is a # of objects, rather than a raw
@@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error {
})
}
-// pendingMapper is for the pending list.
-type pendingMapper struct{}
-
-func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
-
// deferredMapper is for the deferred list.
type deferredMapper struct{}
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
index 887f453a9..c6e8bb31d 100644
--- a/pkg/state/pretty/pretty.go
+++ b/pkg/state/pretty/pretty.go
@@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
buf.WriteString(typ)
buf.WriteString(")(")
buf.WriteString(baseRef)
+ buf.WriteString(")")
for _, component := range x.Dots {
switch v := component.(type) {
case *wire.FieldName:
@@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
}
}
- buf.WriteString(")")
fullRef = buf.String()
}
if p.html {
@@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
// Note that this loop must match the general structure of the
// loop in decode.go. But we don't register type information,
// etc. and just print the raw structures.
+ type objectAndID struct {
+ id uint64
+ obj wire.Object
+ }
var (
tid uint64 = 1
- objects []wire.Object
+ objects []objectAndID
)
- for oid := uint64(1); oid <= length; {
- // Unmarshal the object.
+ for i := uint64(0); i < length; {
+ // Unmarshal either a type object or object ID.
encoded := wire.Load(r)
-
- // Is this a type?
- if typ, ok := encoded.(*wire.Type); ok {
+ switch we := encoded.(type) {
+ case *wire.Type:
str, _ := p.format(graph, 0, encoded)
tag := fmt.Sprintf("g%dt%d", graph, tid)
- p.typeSpecs[tag] = typ
+ p.typeSpecs[tag] = we
if p.html {
// See below.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
@@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
return err
}
tid++
- continue
+ case wire.Uint:
+ // Unmarshal the actual object.
+ objects = append(objects, objectAndID{
+ id: uint64(we),
+ obj: wire.Load(r),
+ })
+ i++
+ default:
+ return fmt.Errorf("wanted type or object ID, got %#v", encoded)
}
-
- // Otherwise, it is a node.
- objects = append(objects, encoded)
- oid++
}
- for i, encoded := range objects {
- // oid starts at 1.
- oid := i + 1
+ for _, objAndID := range objects {
// Format the node.
- str, _ := p.format(graph, 0, encoded)
- tag := fmt.Sprintf("g%dr%d", graph, oid)
+ str, _ := p.format(graph, 0, objAndID.obj)
+ tag := fmt.Sprintf("g%dr%d", graph, objAndID.id)
if p.html {
// Create a little tag with an anchor next to it for linking.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
diff --git a/pkg/state/state.go b/pkg/state/state.go
index acb629969..6b8540f03 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error {
func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
es := encodeState{
- ctx: ctx,
- w: w,
- types: makeTypeEncodeDatabase(),
- zeroValues: make(map[reflect.Type]*objectEncodeState),
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ pending: make(map[objectID]*objectEncodeState),
+ encodedStructs: make(map[reflect.Value]*wire.Struct),
}
// Perform the encoding.
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
index bd2c2b399..69143d194 100644
--- a/pkg/state/tests/struct.go
+++ b/pkg/state/tests/struct.go
@@ -54,12 +54,47 @@ type outerArray struct {
}
// +stateify savable
+type outerSlice struct {
+ inner []inner
+}
+
+// +stateify savable
type inner struct {
v int64
}
// +stateify savable
+type outerFieldValue struct {
+ inner innerFieldValue
+}
+
+// +stateify savable
+type innerFieldValue struct {
+ v int64 `state:".(*savedFieldValue)"`
+}
+
+// +stateify savable
+type savedFieldValue struct {
+ v int64
+}
+
+func (ifv *innerFieldValue) saveV() *savedFieldValue {
+ return &savedFieldValue{ifv.v}
+}
+
+func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) {
+ ifv.v = sfv.v
+}
+
+// +stateify savable
type system struct {
v1 interface{}
v2 interface{}
}
+
+// +stateify savable
+type system3 struct {
+ v1 interface{}
+ v2 interface{}
+ v3 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index de9d17aa7..c91c2c032 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -15,6 +15,7 @@
package tests
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/state"
@@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) {
}
func TestEmbeddedPointers(t *testing.T) {
- var (
- ofs outerSame
- of1 outerFieldFirst
- of2 outerFieldSecond
- oa outerArray
- )
+ // Give each int64 a random value to prevent Go from using
+ // runtime.staticuint64s, which confounds tests for struct duplication.
+ magic := func() int64 {
+ for {
+ n := rand.Int63()
+ if n < 0 || n > 255 {
+ return n
+ }
+ }
+ }
+
+ ofs := outerSame{inner{magic()}}
+ of1 := outerFieldFirst{inner{magic()}, magic()}
+ of2 := outerFieldSecond{magic(), inner{magic()}}
+ oa := outerArray{[2]inner{{magic()}, {magic()}}}
+ osl := outerSlice{oa.inner[:]}
+ ofv := outerFieldValue{innerFieldValue{magic()}}
runTestCases(t, false, "embedded-pointers", []interface{}{
system{&ofs, &ofs.inner},
@@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) {
system{&oa, &oa.inner[1]},
system{&oa.inner[0], &oa},
system{&oa.inner[1], &oa},
+ system3{&oa, &oa.inner[0], &oa.inner[1]},
+ system3{&oa, &oa.inner[1], &oa.inner[0]},
+ system3{&oa.inner[0], &oa, &oa.inner[1]},
+ system3{&oa.inner[1], &oa, &oa.inner[0]},
+ system3{&oa.inner[0], &oa.inner[1], &oa},
+ system3{&oa.inner[1], &oa.inner[0], &oa},
+ system{&oa, &osl},
+ system{&osl, &oa},
+ system{&ofv, &ofv.inner},
+ system{&ofv.inner, &ofv},
})
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 12b061def..b196324c7 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -97,6 +97,9 @@ type testConnection struct {
func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, err
+ }
entry, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&entry, waiter.EventOut)
@@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
// Give c.Read() a chance to block before closing the connection.
@@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
c.SetDeadline(time.Now().Add(time.Minute))
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 6f81b0164..530f2ae2f 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -205,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker {
if !ok {
t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
}
- options := ip.Options()
+ options := []byte(ip.Options())
// cmp.Diff does not consider nil slices equal to empty slices, but we do.
if len(want) == 0 && len(options) == 0 {
return
@@ -859,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker {
}
}
+// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
+func ICMPv4Pointer(want uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Pointer(); got != want {
+ t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
+ }
+ }
+}
+
// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
// This assumes that the payload exactly makes up the rest of the slice.
func ICMPv4Checksum() TransportChecker {
@@ -953,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
}
}
+// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
+// field.
+func ICMPv6TypeSpecific(want uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ if got := icmpv6.TypeSpecific(); got != want {
+ t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
+func ICMPv6Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ payload := icmpv6.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 504408878..2f13dea6a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -99,7 +99,8 @@ const (
// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4ReassemblyTimeout ICMPv4Code = 1
)
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
@@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
+// Pointer returns the pointer field in a Parameter Problem packet.
+func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] }
+
+// SetPointer sets the pointer field in a Parameter Problem packet.
+func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
+
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 4c6e4be64..961b77628 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,7 +39,6 @@ import (
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Options | Padding |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-//
const (
versIHL = 0
tos = 1
@@ -93,7 +93,7 @@ type IPv4Fields struct {
DstAddr tcpip.Address
}
-// IPv4 represents an ipv4 header stored in a byte array.
+// IPv4 is an IPv4 header.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other
@@ -106,10 +106,13 @@ const (
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
- // that there are only 4 bits to represents the header length in 32-bit
- // units, the header cannot exceed 15*4 = 60 bytes.
+ // that there are only 4 bits (max 0xF (15)) to represent the header length
+ // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumOptionsSize is the largest size the IPv4 options can be.
+ IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize
+
// IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
//
// Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
@@ -130,7 +133,7 @@ const (
// IPv4ProtocolNumber is IPv4's network protocol number.
IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
- // IPv4Version is the version of the ipv4 protocol.
+ // IPv4Version is the version of the IPv4 protocol.
IPv4Version = 4
// IPv4AllSystems is the all systems IPv4 multicast address as per
@@ -148,6 +151,13 @@ const (
// packet that every IPv4 capable host must be able to
// process/reassemble.
IPv4MinimumProcessableDatagramSize = 576
+
+ // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791,
+ // section 3.2:
+ // Every internet module must be able to forward a datagram of 68 octets
+ // without further fragmentation. This is because an internet header may be
+ // up to 60 octets, and the minimum fragment is 8 octets.
+ IPv4MinimumMTU = 68
)
// Flags that may be set in an IPv4 packet.
@@ -191,14 +201,13 @@ func IPVersion(b []byte) int {
// Internet Header Length is the length of the internet header in 32
// bit words, and thus points to the beginning of the data. Note that
// the minimum value for a correct header is 5.
-//
const (
ipVersionShift = 4
ipIHLMask = 0x0f
IPv4IHLStride = 4
)
-// HeaderLength returns the value of the "header length" field of the ipv4
+// HeaderLength returns the value of the "header length" field of the IPv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & ipIHLMask) * IPv4IHLStride
@@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) {
b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
-// ID returns the value of the identifier field of the ipv4 header.
+// ID returns the value of the identifier field of the IPv4 header.
func (b IPv4) ID() uint16 {
return binary.BigEndian.Uint16(b[id:])
}
-// Protocol returns the value of the protocol field of the ipv4 header.
+// Protocol returns the value of the protocol field of the IPv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
-// Flags returns the "flags" field of the ipv4 header.
+// Flags returns the "flags" field of the IPv4 header.
func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
@@ -232,41 +241,44 @@ func (b IPv4) More() bool {
return b.Flags()&IPv4FlagMoreFragments != 0
}
-// TTL returns the "TTL" field of the ipv4 header.
+// TTL returns the "TTL" field of the IPv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
}
-// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+// FragmentOffset returns the "fragment offset" field of the IPv4 header.
func (b IPv4) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}
-// TotalLength returns the "total length" field of the ipv4 header.
+// TotalLength returns the "total length" field of the IPv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
-// Checksum returns the checksum field of the ipv4 header.
+// Checksum returns the checksum field of the IPv4 header.
func (b IPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[checksum:])
}
-// SourceAddress returns the "source address" field of the ipv4 header.
+// SourceAddress returns the "source address" field of the IPv4 header.
func (b IPv4) SourceAddress() tcpip.Address {
return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
}
-// DestinationAddress returns the "destination address" field of the ipv4
+// DestinationAddress returns the "destination address" field of the IPv4
// header.
func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// Options returns a a buffer holding the options.
-func (b IPv4) Options() []byte {
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
+// Options returns a buffer holding the options.
+func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
- return b[options:hdrLen:hdrLen]
+ return IPv4Options(b[options:hdrLen:hdrLen])
}
// TransportProtocol implements Network.TransportProtocol.
@@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte {
return b[b.HeaderLength():][:b.PayloadLength()]
}
-// PayloadLength returns the length of the payload portion of the ipv4 packet.
+// PayloadLength returns the length of the payload portion of the IPv4 packet.
func (b IPv4) PayloadLength() uint16 {
return b.TotalLength() - uint16(b.HeaderLength())
}
-// TOS returns the "type of service" field of the ipv4 header.
+// TOS returns the "type of service" field of the IPv4 header.
func (b IPv4) TOS() (uint8, uint32) {
return b[tos], 0
}
-// SetTOS sets the "type of service" field of the ipv4 header.
+// SetTOS sets the "type of service" field of the IPv4 header.
func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
@@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) {
b[ttl] = v
}
-// SetTotalLength sets the "total length" field of the ipv4 header.
+// SetTotalLength sets the "total length" field of the IPv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
-// SetChecksum sets the checksum field of the ipv4 header.
+// SetChecksum sets the checksum field of the IPv4 header.
func (b IPv4) SetChecksum(v uint16) {
binary.BigEndian.PutUint16(b[checksum:], v)
}
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
-// ipv4 header.
+// IPv4 header.
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
@@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) {
binary.BigEndian.PutUint16(b[id:], v)
}
-// SetSourceAddress sets the "source address" field of the ipv4 header.
+// SetSourceAddress sets the "source address" field of the IPv4 header.
func (b IPv4) SetSourceAddress(addr tcpip.Address) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
}
-// SetDestinationAddress sets the "destination address" field of the ipv4
+// SetDestinationAddress sets the "destination address" field of the IPv4
// header.
func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
}
-// CalculateChecksum calculates the checksum of the ipv4 header.
+// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return Checksum(b[:b.HeaderLength()], 0)
}
-// Encode encodes all the fields of the ipv4 header.
+// Encode encodes all the fields of the IPv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
@@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
}
-// EncodePartial updates the total length and checksum fields of ipv4 header,
+// EncodePartial updates the total length and checksum fields of IPv4 header,
// taking in the partial checksum, which is the checksum of the header without
// the total length and checksum fields. It is useful in cases when similar
// packets are produced.
@@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool {
}
return addr[0] == 0x7f
}
+
+// ========================= Options ==========================
+
+// An IPv4OptionType can hold the valuse for the Type in an IPv4 option.
+type IPv4OptionType byte
+
+// These constants are needed to identify individual options in the option list.
+// While RFC 791 (page 31) says "Every internet module must be able to act on
+// every option." This has not generally been adhered to and some options have
+// very low rates of support. We do not support options other than those shown
+// below.
+
+const (
+ // IPv4OptionListEndType is the option type for the End Of Option List
+ // option. Anything following is ignored.
+ IPv4OptionListEndType IPv4OptionType = 0
+
+ // IPv4OptionNOPType is the No-Operation option. May appear between other
+ // options and may appear multiple times.
+ IPv4OptionNOPType IPv4OptionType = 1
+
+ // IPv4OptionRecordRouteType is used by each router on the path of the packet
+ // to record its path. It is carried over to an Echo Reply.
+ IPv4OptionRecordRouteType IPv4OptionType = 7
+
+ // IPv4OptionTimestampType is the option type for the Timestamp option.
+ IPv4OptionTimestampType IPv4OptionType = 68
+
+ // ipv4OptionTypeOffset is the offset in an option of its type field.
+ ipv4OptionTypeOffset = 0
+
+ // IPv4OptionLengthOffset is the offset in an option of its length field.
+ IPv4OptionLengthOffset = 1
+)
+
+// Potential errors when parsing generic IP options.
+var (
+ ErrIPv4OptZeroLength = errors.New("zero length IP option")
+ ErrIPv4OptDuplicate = errors.New("duplicate IP option")
+ ErrIPv4OptInvalid = errors.New("invalid IP option")
+ ErrIPv4OptMalformed = errors.New("malformed IP option")
+ ErrIPv4OptionTruncated = errors.New("truncated IP option")
+ ErrIPv4OptionAddress = errors.New("bad IP option address")
+)
+
+// IPv4Option is an interface representing various option types.
+type IPv4Option interface {
+ // Type returns the type identifier of the option.
+ Type() IPv4OptionType
+
+ // Size returns the size of the option in bytes.
+ Size() uint8
+
+ // Contents returns a slice holding the contents of the option.
+ Contents() []byte
+}
+
+var _ IPv4Option = (*IPv4OptionGeneric)(nil)
+
+// IPv4OptionGeneric is an IPv4 Option of unknown type.
+type IPv4OptionGeneric []byte
+
+// Type implements IPv4Option.
+func (o *IPv4OptionGeneric) Type() IPv4OptionType {
+ return IPv4OptionType((*o)[ipv4OptionTypeOffset])
+}
+
+// Size implements IPv4Option.
+func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) }
+
+// Contents implements IPv4Option.
+func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) }
+
+// IPv4OptionIterator is an iterator pointing to a specific IP option
+// at any point of time. It also holds information as to a new options buffer
+// that we are building up to hand back to the caller.
+type IPv4OptionIterator struct {
+ options IPv4Options
+ // ErrCursor is where we are while parsing options. It is exported as any
+ // resulting ICMP packet is supposed to have a pointer to the byte within
+ // the IP packet where the error was detected.
+ ErrCursor uint8
+ nextErrCursor uint8
+ newOptions [IPv4MaximumOptionsSize]byte
+ writePoint int
+}
+
+// MakeIterator sets up and returns an iterator of options. It also sets up the
+// building of a new option set.
+func (o IPv4Options) MakeIterator() IPv4OptionIterator {
+ return IPv4OptionIterator{
+ options: o,
+ nextErrCursor: IPv4MinimumSize,
+ }
+}
+
+// RemainingBuffer returns the remaining (unused) part of the new option buffer,
+// into which a new option may be written.
+func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options {
+ return IPv4Options(i.newOptions[i.writePoint:])
+}
+
+// ConsumeBuffer marks a portion of the new buffer as used.
+func (i *IPv4OptionIterator) ConsumeBuffer(size int) {
+ i.writePoint += size
+}
+
+// PushNOPOrEnd puts one of the single byte options onto the new options.
+// Only values 0 or 1 (ListEnd or NOP) are valid input.
+func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) {
+ if val > IPv4OptionNOPType {
+ panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val))
+ }
+ i.newOptions[i.writePoint] = byte(val)
+ i.writePoint++
+}
+
+// Finalize returns the completed replacement options buffer padded
+// as needed.
+func (i *IPv4OptionIterator) Finalize() IPv4Options {
+ // RFC 791 page 31 says:
+ // The options might not end on a 32-bit boundary. The internet header
+ // must be filled out with octets of zeros. The first of these would
+ // be interpreted as the end-of-options option, and the remainder as
+ // internet header padding.
+ // Since the buffer is already zero filled we just need to step the write
+ // pointer up to the next multiple of 4.
+ options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3])
+ // Poison the write pointer.
+ i.writePoint = len(i.newOptions)
+ return options
+}
+
+// Next returns the next IP option in the buffer/list of IP options.
+// It returns
+// - A slice of bytes holding the next option or nil if there is error.
+// - A boolean which is true if parsing of all the options is complete.
+// - An error which is non-nil if an error condition was encountered.
+func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
+ // The opts slice gets shorter as we process the options. When we have no
+ // bytes left we are done.
+ if len(i.options) == 0 {
+ return nil, true, nil
+ }
+
+ i.ErrCursor = i.nextErrCursor
+
+ optType := IPv4OptionType(i.options[ipv4OptionTypeOffset])
+
+ if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType {
+ optionBody := i.options[:1]
+ i.options = i.options[1:]
+ i.nextErrCursor = i.ErrCursor + 1
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+ }
+
+ // There are no more single byte options defined. All the rest have a length
+ // field so we need to sanity check it.
+ if len(i.options) == 1 {
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ optLen := i.options[IPv4OptionLengthOffset]
+
+ if optLen == 0 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptZeroLength
+ }
+
+ if optLen == 1 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ if optLen > uint8(len(i.options)) {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptionTruncated
+ }
+
+ optionBody := i.options[:optLen]
+ i.nextErrCursor = i.ErrCursor + optLen
+ i.options = i.options[optLen:]
+
+ // Check the length of some option types that we know.
+ switch optType {
+ case IPv4OptionTimestampType:
+ if optLen < IPv4OptionTimestampHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionTimestamp(optionBody)
+ return &retval, false, nil
+
+ case IPv4OptionRecordRouteType:
+ if optLen < IPv4OptionRecordRouteHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionRecordRoute(optionBody)
+ return &retval, false, nil
+ }
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+}
+
+//
+// IP Timestamp option - RFC 791 page 22.
+// +--------+--------+--------+--------+
+// |01000100| length | pointer|oflw|flg|
+// +--------+--------+--------+--------+
+// | internet address |
+// +--------+--------+--------+--------+
+// | timestamp |
+// +--------+--------+--------+--------+
+// | ... |
+//
+// Type = 68
+//
+// The Option Length is the number of octets in the option counting
+// the type, length, pointer, and overflow/flag octets (maximum
+// length 40).
+//
+// The Pointer is the number of octets from the beginning of this
+// option to the end of timestamps plus one (i.e., it points to the
+// octet beginning the space for next timestamp). The smallest
+// legal value is 5. The timestamp area is full when the pointer
+// is greater than the length.
+//
+// The Overflow (oflw) [4 bits] is the number of IP modules that
+// cannot register timestamps due to lack of space.
+//
+// The Flag (flg) [4 bits] values are
+//
+// 0 -- time stamps only, stored in consecutive 32-bit words,
+//
+// 1 -- each timestamp is preceded with internet address of the
+// registering entity,
+//
+// 3 -- the internet address fields are prespecified. An IP
+// module only registers its timestamp if it matches its own
+// address with the next specified internet address.
+//
+// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC.
+//
+// The Timestamp is a right-justified, 32-bit timestamp in
+// milliseconds since midnight UT. If the time is not available in
+// milliseconds or cannot be provided with respect to midnight UT
+// then any time may be inserted as a timestamp provided the high
+// order bit of the timestamp field is set to one to indicate the
+// use of a non-standard value.
+
+// IPv4OptTSFlags sefines the values expected in the Timestamp
+// option Flags field.
+type IPv4OptTSFlags uint8
+
+//
+// Timestamp option specific related constants.
+const (
+ // IPv4OptionTimestampHdrLength is the length of the timestamp option header.
+ IPv4OptionTimestampHdrLength = 4
+
+ // IPv4OptionTimestampSize is the size of an IP timestamp.
+ IPv4OptionTimestampSize = 4
+
+ // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address.
+ IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize
+
+ // IPv4OptionTimestampMaxSize is limited by space for options
+ IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize
+
+ // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp
+ // is present.
+ IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0
+
+ // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and
+ // IP are present.
+ IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1
+
+ // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that
+ // predefined IP is present.
+ IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3
+)
+
+// ipv4TimestampTime provides the current time as specified in RFC 791.
+func ipv4TimestampTime(clock tcpip.Clock) uint32 {
+ const millisecondsPerDay = 24 * 3600 * 1000
+ const nanoPerMilli = 1000000
+ return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+}
+
+// IP Timestamp option fields.
+const (
+ // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field.
+ IPv4OptTSPointerOffset = 2
+
+ // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow
+ // fields, (each being 4 bits).
+ IPv4OptTSOFLWAndFLGOffset = 3
+ // These constants define the sub byte fields of the Flag and OverFlow field.
+ ipv4OptionTimestampOverflowshift = 4
+ ipv4OptionTimestampFlagsMask byte = 0x0f
+)
+
+var _ IPv4Option = (*IPv4OptionTimestamp)(nil)
+
+// IPv4OptionTimestamp is a Timestamp option from RFC 791.
+type IPv4OptionTimestamp []byte
+
+// Type implements IPv4Option.Type().
+func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType }
+
+// Size implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
+
+// Contents implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+
+// Pointer returns the pointer field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Pointer() uint8 {
+ return (*ts)[IPv4OptTSPointerOffset]
+}
+
+// Flags returns the flags field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags {
+ return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask)
+}
+
+// Overflow returns the Overflow field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Overflow() uint8 {
+ return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift
+}
+
+// IncOverflow increments the Overflow field in the IP Timestamp option. It
+// returns the incremented value. If the return value is 0 then the field
+// overflowed.
+func (ts *IPv4OptionTimestamp) IncOverflow() uint8 {
+ (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift
+ return ts.Overflow()
+}
+
+// UpdateTimestamp updates the fields of the next free timestamp slot.
+func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) {
+ slot := (*ts)[ts.Pointer()-1:]
+
+ switch ts.Flags() {
+ case IPv4OptionTimestampOnlyFlag:
+ binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize
+ case IPv4OptionTimestampWithIPFlag:
+ if n := copy(slot, addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ case IPv4OptionTimestampWithPredefinedIPFlag:
+ if tcpip.Address(slot[:IPv4AddressSize]) == addr {
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ }
+ }
+}
+
+// RecordRoute option specific related constants.
+//
+// from RFC 791 page 20:
+// Record Route
+//
+// +--------+--------+--------+---------//--------+
+// |00000111| length | pointer| route data |
+// +--------+--------+--------+---------//--------+
+// Type=7
+//
+// The record route option provides a means to record the route of
+// an internet datagram.
+//
+// The option begins with the option type code. The second octet
+// is the option length which includes the option type code and the
+// length octet, the pointer octet, and length-3 octets of route
+// data. The third octet is the pointer into the route data
+// indicating the octet which begins the next area to store a route
+// address. The pointer is relative to this option, and the
+// smallest legal value for the pointer is 4.
+const (
+ // IPv4OptionRecordRouteHdrLength is the length of the Record Route option
+ // header.
+ IPv4OptionRecordRouteHdrLength = 3
+
+ // IPv4OptRRPointerOffset is the offset to the pointer field in an RR
+ // option, which points to the next free slot in the list of addresses.
+ IPv4OptRRPointerOffset = 2
+)
+
+var _ IPv4Option = (*IPv4OptionRecordRoute)(nil)
+
+// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791.
+type IPv4OptionRecordRoute []byte
+
+// Pointer returns the pointer field in the IP RecordRoute option.
+func (rr *IPv4OptionRecordRoute) Pointer() uint8 {
+ return (*rr)[IPv4OptRRPointerOffset]
+}
+
+// StoreAddress stores the given IPv4 address into the next free slot.
+func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) {
+ start := rr.Pointer() - 1 // A one based number.
+ // start and room checked by caller.
+ if n := copy((*rr)[start:], addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize
+}
+
+// Type implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType }
+
+// Size implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
+
+// Contents implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index c5d8a3456..4e7e5f76a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -101,8 +101,10 @@ const (
// The address is ff02::2.
IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
- // section 5.
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
+ // section 5:
+ // IPv6 requires that every link in the Internet have an MTU of 1280 octets
+ // or greater. This is known as the IPv6 minimum link MTU.
IPv6MinimumMTU = 1280
// IPv6Loopback is the IPv6 Loopback address.
@@ -373,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
+// address.
+func IsV6LoopbackAddress(addr tcpip.Address) bool {
+ return addr == IPv6Loopback
+}
+
// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
// link-local multicast address.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index dc239a0d0..2777f1411 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) {
const count = 1000000
var wg sync.WaitGroup
+ defer wg.Wait()
wg.Add(1)
go func() {
defer wg.Done()
@@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) {
}
}()
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
}
-
- rx.Flush()
}
- }()
- wg.Wait()
+ rx.Flush()
+ }
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 560477926..b3e8c4b92 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
//
// We don't clone the original packet buffer so that the new packet buffer
// does not have any of its headers set.
- pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
+ //
+ // We trim the link headers from the cloned buffer as the sniffer doesn't
+ // handle link headers.
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ vv.TrimFront(len(pkt.LinkHeader().View()))
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
switch protocol {
case header.IPv4ProtocolNumber:
if ok := parse.IPv4(pkt); !ok {
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 0243424f6..86f14db76 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "tun_endpoint_refs.go",
package = "tun",
prefix = "tunEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tunEndpoint",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index f94491026..cda6328a2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// 2. Creating a new NIC.
id := tcpip.NICID(s.UniqueID())
- // TODO(gvisor.dev/1486): enable leak check for tunEndpoint.
endpoint := &tunEndpoint{
Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
stack: s,
@@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
name: name,
isTap: prefix == "tap",
}
+ endpoint.EnableLeakCheck()
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index b40dde96b..8a6bcfc2c 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -30,5 +30,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7df77c66e..33a4a0720 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -18,6 +18,7 @@
package arp
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -121,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
@@ -144,34 +145,43 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
remoteAddr := tcpip.Address(h.ProtocolAddressSender())
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- // As per RFC 826, under Packet Reception:
- // Swap hardware and protocol fields, putting the local hardware and
- // protocol addresses in the sender fields.
- //
- // Send the packet to the (new) target hardware address on the same
- // hardware on which the request was received.
- origSender := h.HardwareAddressSender()
- r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
+ respPkt.NetworkProtocolNumber = ProtocolNumber
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), origSender)
- copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress())
+ if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ origSender := h.HardwareAddressSender()
+ if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize))
+ }
+ if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -199,6 +209,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
+ stack *stack.Stack
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -227,26 +238,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- r := &stack.Route{
- NetProto: ProtocolNumber,
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ if len(remoteLinkAddr) == 0 {
+ remoteLinkAddr = header.EthernetBroadcastAddress
}
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+
+ nicID := nic.ID()
+ if len(localAddr) == 0 {
+ addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
+ }
+
+ if len(addr.Address) == 0 {
+ return tcpip.ErrNetworkUnreachable
+ }
+
+ localAddr = addr.Address
+ } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return tcpip.ErrBadLocalAddress
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ pkt.NetworkProtocolNumber = ProtocolNumber
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), linkEP.LinkAddress())
- copy(h.ProtocolAddressSender(), localAddr)
- copy(h.ProtocolAddressTarget(), addr)
-
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
@@ -286,6 +315,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
// address must be added to every NIC that should respond to ARP requests. See
// ProtocolAddress for more details.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 626af975a..087ee9c66 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -78,13 +79,11 @@ func (t eventType) String() string {
type eventInfo struct {
eventType eventType
nicID tcpip.NICID
- addr tcpip.Address
- linkAddr tcpip.LinkAddress
- state stack.NeighborState
+ entry stack.NeighborEntry
}
func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
}
// arpDispatcher implements NUDDispatcher to validate the dispatching of
@@ -96,35 +95,29 @@ type arpDispatcher struct {
var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
-func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryChanged,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryRemoved,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
@@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address,
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
}
case <-ctx.Done():
@@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
wantEvent := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: test.senderAddr,
- linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- state: stack.Stale,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ State: stack.Stale,
+ },
}
if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
t.Fatal(err)
@@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
}
- if got, want := neigh.LocalAddr, stackAddr; got != want {
- t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
- }
if got, want := neigh.State, stack.Stale; got != want {
t.Errorf("got neighbor State = %s, want = %s", got, want)
}
@@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.LinkEndpoint
+
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
+ testAddr := tcpip.Address([]byte{1, 2, 3, 4})
+
tests := []struct {
name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Unicast with no local address available",
remoteLinkAddr: remoteLinkAddr,
- expectLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
{
- name: "Multicast",
+ name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectLinkAddr: header.EthernetBroadcastAddress,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
for _, test := range tests {
- p := arp.NewProtocol(nil)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(arp.ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
- }
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ }
+ }
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
- }
+ // We pass a test network interface to LinkAddressRequest with the same
+ // NIC ID and link endpoint used by the NIC we created earlier so that we
+ // can mock a link address request and observe the packets sent to the
+ // link endpoint even though the stack uses the real NIC to validate the
+ // local address.
+ if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ }
+
+ rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index ed502a473..936601287 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
+//
+// releaseCB is a callback that will run when the fragment reassembly of a
+// packet is complete or cancelled. releaseCB take a a boolean argument which is
+// true iff the reassembly is cancelled due to timeout. releaseCB should be
+// passed only with the first fragment of a packet. If more than one releaseCB
+// are passed for the same packet, only the first releaseCB will be saved for
+// the packet and the succeeding ones will be dropped by running them
+// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -171,6 +179,12 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
+ if releaseCB != nil {
+ if !r.setCallback(releaseCB) {
+ // We got a duplicate callback. Release it immediately.
+ releaseCB(false /* timedOut */)
+ }
+ }
f.mu.Unlock()
res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
@@ -178,14 +192,14 @@ func (f *Fragmentation) Process(
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
- f.release(r)
+ f.release(r, false /* timedOut */)
f.mu.Unlock()
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
if done {
- f.release(r)
+ f.release(r, false /* timedOut */)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
@@ -195,14 +209,14 @@ func (f *Fragmentation) Process(
if tail == nil {
break
}
- f.release(tail)
+ f.release(tail, false /* timedOut */)
}
}
f.mu.Unlock()
return res, firstFragmentProto, done, nil
}
-func (f *Fragmentation) release(r *reassembler) {
+func (f *Fragmentation) release(r *reassembler, timedOut bool) {
// Before releasing a fragment we need to check if r is already marked as done.
// Otherwise, we would delete it twice.
if r.checkDoneOrMark() {
@@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) {
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
f.size = 0
}
+
+ r.release(timedOut) // releaseCB may run.
}
// releaseReassemblersLocked releases already-expired reassemblers, then
@@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() {
break
}
// If the oldest reassembler has already expired, release it.
- f.release(r)
+ f.release(r, true /* timedOut*/)
}
}
// PacketFragmenter is the book-keeping struct for packet fragmentation.
type PacketFragmenter struct {
- transportHeader buffer.View
- data buffer.VectorisedView
- reserve int
- innerMTU int
- fragmentCount int
- currentFragment int
- fragmentOffset int
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ fragmentPayloadLen int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
}
// MakePacketFragmenter prepares the struct needed for packet fragmentation.
//
// pkt is the packet to be fragmented.
//
-// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
// have.
//
// reserve is the number of bytes that should be reserved for the headers in
// each generated fragment.
-func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
// As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
// repeated in each fragment. However we do not currently support any header
// of that kind yet, so the following computation is valid for both IPv4 and
@@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa
var fragmentableData buffer.VectorisedView
fragmentableData.AppendView(pkt.TransportHeader().View())
fragmentableData.Append(pkt.Data)
- fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+ fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
return PacketFragmenter{
- data: fragmentableData,
- reserve: reserve,
- innerMTU: innerMTU,
- fragmentCount: fragmentCount,
+ data: fragmentableData,
+ reserve: reserve,
+ fragmentPayloadLen: int(fragmentPayloadLen),
+ fragmentCount: int(fragmentCount),
}
}
@@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int,
})
// Copy data for the fragment.
- copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
offset := pf.fragmentOffset
pf.fragmentOffset += copied
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index d3c7d7f92..5dcd10730 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) {
f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
if err != nil {
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
@@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) {
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) {
func TestMemoryLimits(t *testing.T) {
f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) {
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
got := f.size
want := 1
@@ -377,7 +377,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) {
tests := []struct {
name string
- innerMTU int
+ fragmentPayloadLen uint32
transportHeaderLen int
payloadSize int
wantFragments []fragmentInfo
}{
{
name: "Packet exactly fits in MTU",
- innerMTU: 1280,
+ fragmentPayloadLen: 1280,
transportHeaderLen: 0,
payloadSize: 1280,
wantFragments: []fragmentInfo{
@@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet exactly does not fit in MTU",
- innerMTU: 1000,
+ fragmentPayloadLen: 1000,
transportHeaderLen: 0,
payloadSize: 1001,
wantFragments: []fragmentInfo{
@@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a transport header",
- innerMTU: 560,
+ fragmentPayloadLen: 560,
transportHeaderLen: 40,
payloadSize: 560,
wantFragments: []fragmentInfo{
@@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a huge transport header",
- innerMTU: 500,
+ fragmentPayloadLen: 500,
transportHeaderLen: 1300,
payloadSize: 500,
wantFragments: []fragmentInfo{
@@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) {
originalPayload.AppendView(pkt.TransportHeader().View())
originalPayload.Append(pkt.Data)
var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
for i := 0; ; i++ {
fragPkt, offset, copied, more := pf.BuildNextFragment()
wantFragment := test.wantFragments[i]
@@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) {
if more != wantFragment.more {
t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
}
- if got := fragPkt.Size(); got > test.innerMTU {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
}
if got := fragPkt.AvailableHeaderBytes(); got != reserve {
t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
@@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) {
})
}
}
+
+func TestReleaseCallback(t *testing.T) {
+ const (
+ proto = 99
+ )
+
+ var result int
+ var callbackReasonIsTimeout bool
+ cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+
+ tests := []struct {
+ name string
+ callbacks []func(bool)
+ timeout bool
+ wantResult int
+ wantCallbackReasonIsTimeout bool
+ }{
+ {
+ name: "callback runs on release",
+ callbacks: []func(bool){cb1},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "first callback is nil",
+ callbacks: []func(bool){nil, cb2},
+ timeout: false,
+ wantResult: 2,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "two callbacks - first one is set",
+ callbacks: []func(bool){cb1, cb2},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "callback runs on timeout",
+ callbacks: []func(bool){cb1},
+ timeout: true,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: true,
+ },
+ {
+ name: "no callbacks",
+ callbacks: []func(bool){nil},
+ timeout: false,
+ wantResult: 0,
+ wantCallbackReasonIsTimeout: false,
+ },
+ }
+
+ id := FragmentID{ID: 0}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ result = 0
+ callbackReasonIsTimeout = false
+
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+
+ for i, cb := range test.callbacks {
+ _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
+ if err != nil {
+ t.Errorf("f.Process error = %s", err)
+ }
+ }
+
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatalf("Reassemberr not found")
+ }
+ f.release(r, test.timeout)
+
+ if result != test.wantResult {
+ t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ }
+ if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
+ t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9bb051a30..c0cc0bde0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -41,6 +41,7 @@ type reassembler struct {
heap fragHeap
done bool
creationTime int64
+ callback func(bool)
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
@@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
+
+func (r *reassembler) setCallback(c func(bool)) bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.callback != nil {
+ return false
+ }
+ r.callback = c
+ return true
+}
+
+func (r *reassembler) release(timedOut bool) {
+ r.mu.Lock()
+ callback := r.callback
+ r.callback = nil
+ r.mu.Unlock()
+
+ if callback != nil {
+ callback(timedOut)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..fa2a70dc8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
+
+func TestSetCallback(t *testing.T) {
+ result := 0
+ reasonTimeout := false
+
+ cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
+
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ if !r.setCallback(cb1) {
+ t.Errorf("setCallback failed")
+ }
+ if r.setCallback(cb2) {
+ t.Errorf("setCallback should fail if one is already set")
+ }
+ r.release(true)
+ if result != 1 {
+ t.Errorf("got result = %d, want = 1", result)
+ }
+ if !reasonTimeout {
+ t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f20b94d97..8873bd91f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
- t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
+ netHdr := pkt.Network()
+ t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress())
t.dataCalls++
return stack.TransportPacketHandled
}
@@ -304,6 +305,10 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -604,7 +609,8 @@ func TestIPv4Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -690,6 +696,10 @@ func TestIPv4ReceiveControl(t *testing.T) {
view[i] = uint8(i)
}
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
+
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
nic.testObject.protocol = 10
@@ -699,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -780,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
@@ -792,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -892,7 +906,8 @@ func TestIPv6Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -1009,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -1063,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum tcpip.NetworkProtocolNumber
nicAddr tcpip.Address
remoteAddr tcpip.Address
- pktGen func(*testing.T, tcpip.Address) buffer.View
+ pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
expectedErr *tcpip.Error
}{
@@ -1073,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1087,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1115,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1129,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1139,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1148,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1158,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1167,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1195,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
@@ -1213,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
}
- return hdr.View()
+ return hdr.View().ToVectorisedView()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options and data across views",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ vv := buffer.View(ip).ToVectorisedView()
+ vv.AppendView(ipv4Options)
+ vv.AppendView(data)
+ return vv
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1243,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1256,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1283,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1299,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1326,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1334,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1361,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1369,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1413,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
defer r.Release()
if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ Data: test.pktGen(t, subTest.srcAddr),
})); err != test.expectedErr {
t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7fc12e229..6252614ec 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -29,6 +29,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3407755ed..9b5e37fee 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -23,10 +24,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// handleControl handles the case when an ICMP packet contains the headers of
-// the original packet that caused the ICMP one to be sent. This information is
-// used to find out which transport endpoint must be notified about the ICMP
-// packet.
+// handleControl handles the case when an ICMP error packet contains the headers
+// of the original packet that caused the ICMP one to be sent. This information
+// is used to find out which transport endpoint must be notified about the ICMP
+// packet. We only expect the payload, not the enclosing ICMP packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
@@ -41,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
- src := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
+ srcAddr := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 {
return
}
@@ -57,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
- stats := r.Stats()
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
received := stats.ICMP.V4PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
@@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
h := header.ICMPv4(v)
+ // Only do in-stack processing if the checksum is correct.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
+ received.Invalid.Increment()
+ // It's possible that a raw socket expects to receive this regardless
+ // of checksum errors. If it's an echo request we know it's safe because
+ // we are the only handler, however other types do not cope well with
+ // packets with checksum errors.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
+ }
+ return
+ }
+
+ iph := header.IPv4(pkt.NetworkHeader().View())
+ var newOptions header.IPv4Options
+ if len(iph) > header.IPv4MinimumSize {
+ // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
+ // type ICMP packets):
+ // If a Record Route and/or Time Stamp option is received in an
+ // ICMP Echo Request, this option (these options) SHOULD be
+ // updated to include the current host and included in the IP
+ // header of the Echo Reply message, without "truncation".
+ // Thus, the recorded route will be for the entire round trip.
+ //
+ // So we need to let the option processor know how it should handle them.
+ var op optionsUsage
+ if h.Type() == header.ICMPv4Echo {
+ op = &optionUsageEcho{}
+ } else {
+ op = &optionUsageReceive{}
+ }
+ aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ newOptions = tmp
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
received.Echo.Increment()
- // Only send a reply if the checksum is valid.
- headerChecksum := h.Checksum()
- h.SetChecksum(0)
- calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- h.SetChecksum(headerChecksum)
- if calculatedChecksum != headerChecksum {
- // It's possible that a raw socket still expects to receive this.
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- received.Invalid.Increment()
+ sent := stats.ICMP.V4PacketsSent
+ if !e.protocol.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return
}
@@ -98,19 +144,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
replyData := pkt.Data.ToOwnedView()
- replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
+
+ // It's possible that a raw socket expects to receive this.
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
+ pkt = nil
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ // Take the base of the incoming request IP header but replace the options.
+ replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
+ replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...))
+ replyIPHdr.SetHeaderLength(replyHeaderLength)
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
- localAddr := r.LocalAddress
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
+ localAddr := ipHdr.DestinationAddress()
+ if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -139,7 +193,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// The fields we need to alter.
//
// We need to produce the entire packet in the data segment in order to
- // use WriteHeaderIncludedPacket().
+ // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the
+ // total length and the header checksum so we don't need to set those here.
replyIPHdr.SetSourceAddress(r.LocalAddress)
replyIPHdr.SetDestinationAddress(r.RemoteAddress)
replyIPHdr.SetTTL(r.DefaultTTL())
@@ -157,8 +212,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
})
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // The checksum will be calculated so we don't need to do it here.
- sent := stats.ICMP.V4PacketsSent
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
@@ -168,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
@@ -182,8 +235,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
}
case header.ICMPv4SrcQuench:
@@ -234,12 +290,31 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
+// icmpReasonParamProblem is an error to use to request a Parameter Problem
+// message to be sent.
+type icmpReasonParamProblem struct {
+ pointer byte
+}
+
+func (*icmpReasonParamProblem) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv4(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -263,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
//
// TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
// response to a non-initial fragment, but it currently can not happen.
-
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any {
return nil
}
@@ -272,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
sent := p.stack.Stats().ICMP.V4PacketsSent
if !p.stack.AllowICMPMessage() {
@@ -287,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
- if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
// TODO(gvisor.dev/issue/3810):
// Unfortunately the current stack pretty much always has ICMPv4 headers
// in the Data section of the packet but there is no guarantee that is the
@@ -348,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -360,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// view with the entire incoming IP packet reassembled and truncated as
// required. This is now the payload of the new ICMP packet and no longer
// considered a packet in its own right.
- newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader := append(buffer.View(nil), origIPHdr...)
newHeader = append(newHeader, transportHeader...)
payload := newHeader.ToVectorisedView()
payload.AppendView(pkt.Data.ToView())
@@ -374,17 +444,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- switch reason.(type) {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ counter = sent.DstUnreachable
case *icmpReasonProtoUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
+ counter = sent.TimeExceeded
+ case *icmpReasonParamProblem:
+ icmpHdr.SetType(header.ICMPv4ParamProblem)
+ icmpHdr.SetCode(header.ICMPv4UnusedCode)
+ icmpHdr.SetPointer(reason.pointer)
+ counter = sent.ParamProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
- counter := sent.DstUnreachable
if err := route.WritePacket(
nil, /* gso */
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e7c58ae0a..cfd0c505a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -16,7 +16,9 @@
package ipv4
import (
+ "errors"
"fmt"
+ "math"
"sync/atomic"
"time"
@@ -31,6 +33,8 @@ import (
)
const (
+ // ReassembleTimeout is the time a packet stays in the reassembly
+ // system before being evicted.
// As per RFC 791 section 3.2:
// The current recommendation for the initial timer setting is 15 seconds.
// This may be changed as experience with this protocol accumulates.
@@ -38,7 +42,7 @@ const (
// Considering that it is an old recommendation, we use the same reassembly
// timeout that linux defines, which is 30 seconds:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
- reassembleTimeout = 30 * time.Second
+ ReassembleTimeout = 30 * time.Second
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
@@ -211,18 +219,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
-func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU())
-}
-
// handleFragments fragments pkt and calls the handler function on each
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
-// original packet. The mtu is the maximum size of the packets.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+// original packet.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ // Round the MTU down to align to 8 bytes.
+ fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader))
var n int
for {
@@ -247,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
@@ -265,23 +269,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, pkt)
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -292,6 +313,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
+
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -311,17 +333,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.addIPHeader(r, pkt, params)
- if e.packetMustBeFragmented(pkt, gso) {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
return nil
}); err != nil {
- panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err))
+ panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err))
}
// Remove the packet that was just fragmented and process the rest.
pkts.Remove(originalPkt)
@@ -355,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -385,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return tcpip.ErrMalformedHeader
}
+
+ hdrLen := header.IPv4(h).HeaderLength()
+ if hdrLen < header.IPv4MinimumSize {
+ return tcpip.ErrMalformedHeader
+ }
+
+ h, ok = pkt.Data.PullUp(int(hdrLen))
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
ip := header.IPv4(h)
// Always set the total length.
@@ -429,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -462,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -470,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -480,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -488,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -502,10 +545,30 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
+
+ // Set up a callback in case we need to send a Time Exceeded Message, as per
+ // RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards
+ // the datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+ }
+ }
+
var ready bool
var err error
proto := h.Protocol()
@@ -523,29 +586,56 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
h.More(),
proto,
pkt.Data,
+ releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if !ready {
return
}
+
+ // The reassembler doesn't take care of fixing up the header, so we need
+ // to do it here.
+ h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
+ h.SetFlagsFragmentOffset(0, 0)
}
+ stats.IP.PacketsDelivered.Increment()
- r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
// TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
// headers, the setting of the transport number here should be
// unnecessary and removed.
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt)
+ e.handleICMP(pkt)
return
}
+ if len(h.Options()) != 0 {
+ // TODO(gvisor.dev/issue/4586):
+ // When we add forwarding support we should use the verified options
+ // rather than just throwing them away.
+ aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ }
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
@@ -553,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC: 1122 Section 3.2.2.1
// A host SHOULD generate Destination Unreachable messages with code:
// 2 (Protocol Unreachable), when the designated transport protocol
// is not supported
- _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -602,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
loopback := e.nic.IsLoopback()
addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
@@ -778,26 +868,32 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload mtu.
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv4MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
}
- return mtu - header.IPv4MinimumSize
-}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+ // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
+ // length:
+ // The maximal internet header is 60 octets, and a typical internet header
+ // is 20 octets, allowing a margin for headers of higher level protocols.
+ if networkHeaderSize > header.IPv4MaximumHeaderSize {
+ return 0, tcpip.ErrMalformedHeader
}
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- return mtu
+
+ networkMTU := linkMTU
+ if networkMTU > MaxTotalSize {
+ networkMTU = MaxTotalSize
+ }
+
+ return networkMTU - uint32(networkHeaderSize), nil
+}
+
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// addressToUint32 translates an IPv4 address into its little endian uint32
@@ -836,7 +932,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
}
}
@@ -846,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
originalIPHeaderLength := len(originalIPHeader)
nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
@@ -862,3 +959,324 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
return fragPkt, more
}
+
+// optionAction describes possible actions that may be taken on an option
+// while processing it.
+type optionAction uint8
+
+const (
+ // optionRemove says that the option should not be in the output option set.
+ optionRemove optionAction = iota
+
+ // optionProcess says that the option should be fully processed.
+ optionProcess
+
+ // optionVerify says the option should be checked and passed unchanged.
+ optionVerify
+
+ // optionPass says to pass the output set without checking.
+ optionPass
+)
+
+// optionActions list what to do for each option in a given scenario.
+type optionActions struct {
+ // timestamp controls what to do with a Timestamp option.
+ timestamp optionAction
+
+ // recordroute controls what to do with a Record Route option.
+ recordRoute optionAction
+
+ // unknown controls what to do with an unknown option.
+ unknown optionAction
+}
+
+// optionsUsage specifies the ways options may be operated upon for a given
+// scenario during packet processing.
+type optionsUsage interface {
+ actions() optionActions
+}
+
+// optionUsageReceive implements optionsUsage for received packets.
+type optionUsageReceive struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageReceive) actions() optionActions {
+ return optionActions{
+ timestamp: optionVerify,
+ recordRoute: optionVerify,
+ unknown: optionPass,
+ }
+}
+
+// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it
+// is enabled (Process, Process, Pass) and for fragmenting (Process, Process,
+// Pass for frag1, but Remove,Remove,Remove for all other frags).
+
+// optionUsageEcho implements optionsUsage for echo packet processing.
+type optionUsageEcho struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageEcho) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
+ unknown: optionRemove,
+ }
+}
+
+var (
+ errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
+ errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
+ errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
+ errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
+)
+
+// handleTimestamp does any required processing on a Timestamp option
+// in place.
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+ flags := tsOpt.Flags()
+ var entrySize uint8
+ switch flags {
+ case header.IPv4OptionTimestampOnlyFlag:
+ entrySize = header.IPv4OptionTimestampSize
+ case
+ header.IPv4OptionTimestampWithIPFlag,
+ header.IPv4OptionTimestampWithPredefinedIPFlag:
+ entrySize = header.IPv4OptionTimestampWithAddrSize
+ default:
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ }
+
+ pointer := tsOpt.Pointer()
+ // To simplify processing below, base further work on the array of timestamps
+ // beyond the header, rather than on the whole option. Also to aid
+ // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
+ nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1)
+ optLen := tsOpt.Size()
+ dataLength := optLen - header.IPv4OptionTimestampHdrLength
+
+ // In the section below, we verify the pointer, length and overflow counter
+ // fields of the option. The distinction is in which byte you return as being
+ // in error in the ICMP packet. Offsets 1 (length), 2 pointer)
+ // or 3 (overflowed counter).
+ //
+ // The following RFC sections cover this section:
+ //
+ // RFC 791 (page 22):
+ // If there is some room but not enough room for a full timestamp
+ // to be inserted, or the overflow count itself overflows, the
+ // original datagram is considered to be in error and is discarded.
+ // In either case an ICMP parameter problem message may be sent to
+ // the source host [3].
+ //
+ // You can get this situation in two ways. Firstly if the data area is not
+ // a multiple of the entry size or secondly, if the pointer is not at a
+ // multiple of the entry size. The wording of the RFC suggests that
+ // this is not an error until you actually run out of space.
+ if pointer > optLen {
+ // RFC 791 (page 22) says we should switch to using the overflow count.
+ // If the timestamp data area is already full (the pointer exceeds
+ // the length) the datagram is forwarded without inserting the
+ // timestamp, but the overflow count is incremented by one.
+ if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
+ // By definition we have nothing to do.
+ return 0, nil
+ }
+
+ if tsOpt.IncOverflow() != 0 {
+ return 0, nil
+ }
+ // The overflow count is also full.
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ }
+ if nextSlot+entrySize > dataLength {
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad.
+ if false {
+ // We must select Pointer for Linux compatibility, even if
+ // only the length is bad.
+ // The Linux code is at (in October 2020)
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ // which doesn't distinguish between which of optptr[2] or optlen
+ // is wrong, but just arbitrarily decides on optptr+2.
+ if dataLength%entrySize != 0 {
+ // The Data section size should be a multiple of the expected
+ // timestamp entry size.
+ return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ }
+ // If the size is OK, the pointer must be corrupted.
+ }
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
+
+ if usage.actions().timestamp == optionProcess {
+ tsOpt.UpdateTimestamp(localAddress, clock)
+ }
+ return 0, nil
+}
+
+var (
+ errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
+ errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
+)
+
+// handleRecordRoute checks and processes a Record route option. It is much
+// like the timestamp type 1 option, but without timestamps. The passed in
+// address is stored in the option in the correct spot if possible.
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+ optlen := rrOpt.Size()
+
+ if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+
+ nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+
+ // RFC 791 page 21 says
+ // If the route data area is already full (the pointer exceeds the
+ // length) the datagram is forwarded without inserting the address
+ // into the recorded route. If there is some room but not enough
+ // room for a full address to be inserted, the original datagram is
+ // considered to be in error and is discarded. In either case an
+ // ICMP parameter problem message may be sent to the source
+ // host.
+ // The use of the words "In either case" suggests that a 'full' RR option
+ // could generate an ICMP at every hop after it fills up. We chose to not
+ // do this (as do most implementations). It is probable that the inclusion
+ // of these words is a copy/paste error from the timestamp option where
+ // there are two failure reasons given.
+ if nextSlot >= optlen {
+ return 0, nil
+ }
+
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad. We must select Pointer for Linux
+ // compatibility, even if only the length is bad.
+ if nextSlot+header.IPv4AddressSize > optlen {
+ if false {
+ // This is what we would do if we were not being Linux compatible.
+ // Check for bad pointer or length value. Must be a multiple of 4 after
+ // accounting for the 3 byte header and not within that header.
+ // RFC 791, page 20 says:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ //
+ // A recorded route is composed of a series of internet addresses.
+ // Each internet address is 32 bits or 4 octets.
+ // Linux skips this test so we must too. See Linux code at:
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
+ // Length is bad, not on integral number of slots.
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+ // If not length, the fault must be with the pointer.
+ }
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
+ if usage.actions().recordRoute == optionVerify {
+ return 0, nil
+ }
+ rrOpt.StoreAddress(localAddress)
+ return 0, nil
+}
+
+// processIPOptions parses the IPv4 options and produces a new set of options
+// suitable for use in the next step of packet processing as informed by usage.
+// The original will not be touched.
+//
+// Returns
+// - The location of an error if there was one (or 0 if no error)
+// - If there is an error, information as to what it was was.
+// - The replacement option set.
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+ stats := e.protocol.stack.Stats()
+ opts := header.IPv4Options(orig)
+ optIter := opts.MakeIterator()
+
+ // Each option other than NOP must only appear (RFC 791 section 3.1, at the
+ // definition of every type). Keep track of each of the possible types in
+ // the 8 bit 'type' field.
+ var seenOptions [math.MaxUint8 + 1]bool
+
+ // TODO(gvisor.dev/issue/4586):
+ // This will need tweaking when we start really forwarding packets
+ // as we may need to get two addresses, for rx and tx interfaces.
+ // We will also have to take usage into account.
+ prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
+ localAddress := prefixedAddress.Address
+ if err != nil {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
+ return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ }
+ localAddress = dstAddr
+ }
+
+ for {
+ option, done, err := optIter.Next()
+ if done || err != nil {
+ return optIter.ErrCursor, optIter.Finalize(), err
+ }
+ optType := option.Type()
+ if optType == header.IPv4OptionNOPType {
+ optIter.PushNOPOrEnd(optType)
+ continue
+ }
+ if optType == header.IPv4OptionListEndType {
+ optIter.PushNOPOrEnd(optType)
+ return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ }
+
+ // check for repeating options (multiple NOPs are OK)
+ if seenOptions[optType] {
+ return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ }
+ seenOptions[optType] = true
+
+ optLen := int(option.Size())
+ switch option := option.(type) {
+ case *header.IPv4OptionTimestamp:
+ stats.IP.OptionTSReceived.Increment()
+ if usage.actions().timestamp != optionRemove {
+ clock := e.protocol.stack.Clock()
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ case *header.IPv4OptionRecordRoute:
+ stats.IP.OptionRRReceived.Increment()
+ if usage.actions().recordRoute != optionRemove {
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ default:
+ stats.IP.OptionUnknownReceived.Increment()
+ if usage.actions().unknown == optionPass {
+ newBuffer := optIter.RemainingBuffer()[:optLen]
+ // Arguments already heavily checked.. ignore result.
+ _ = copy(newBuffer, option.Contents())
+ optIter.ConsumeBuffer(optLen)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index fee11bb38..c7f434591 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -21,11 +21,13 @@ import (
"math"
"net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -39,7 +41,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-const extraHeaderReserve = 50
+const (
+ extraHeaderReserve = 50
+ defaultMTU = 65536
+)
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
@@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
ep = sniffer.New(ep)
@@ -103,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) {
// checks the response.
func TestIPv4Sanity(t *testing.T) {
const (
- defaultMTU = header.IPv6MinimumMTU
ttl = 255
nicID = 1
randomSequence = 123
@@ -118,27 +121,29 @@ func TestIPv4Sanity(t *testing.T) {
)
tests := []struct {
- name string
- headerLength uint8 // value of 0 means "use correct size"
- badHeaderChecksum bool
- maxTotalLength uint16
- transportProtocol uint8
- TTL uint8
- shouldFail bool
- expectICMP bool
- ICMPType header.ICMPv4Type
- ICMPCode header.ICMPv4Code
- options []byte
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ badHeaderChecksum bool
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ options []byte
+ replyOptions []byte // if succeeds, reply should look like this
+ shouldFail bool
+ expectErrorICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ paramProblemPointer uint8
}{
{
- name: "valid",
- maxTotalLength: defaultMTU,
+ name: "valid no options",
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
},
{
name: "bad header checksum",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
badHeaderChecksum: true,
@@ -157,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) {
// received with TTL less than 2.
{
name: "zero TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 0,
- shouldFail: false,
},
{
name: "one TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 1,
- shouldFail: false,
},
{
name: "End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{0, 0, 0, 0},
+ replyOptions: []byte{0, 0, 0, 0},
},
{
name: "NOP options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 1, 1},
+ replyOptions: []byte{1, 1, 1, 1},
},
{
name: "NOP and End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 0, 0},
+ replyOptions: []byte{1, 1, 0, 0},
},
{
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (0)",
@@ -205,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip - 1)",
@@ -213,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip + icmp - 1)",
@@ -221,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad protocol",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: 99,
TTL: ttl,
shouldFail: true,
- expectICMP: true,
+ expectErrorICMP: true,
ICMPType: header.ICMPv4DstUnreachable,
ICMPCode: header.ICMPv4ProtoUnreachable,
},
+ {
+ name: "timestamp option overflow",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ },
+ {
+ name: "timestamp option overflow full",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0xF1,
+ // ^ Counter full (15/0xF)
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ replyOptions: []byte{},
+ },
+ {
+ name: "unknown option",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{10, 4, 9, 0},
+ // ^^
+ // The unknown option should be stripped out of the reply.
+ replyOptions: []byte{},
+ },
+ {
+ name: "bad option - length 0",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 0, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ name: "bad option - length big",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 9, 9, 0,
+ // ^
+ // There are only 8 bytes allocated to options so 9 bytes of timestamp
+ // space is not possible. (Second byte)
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ // This tests for some linux compatible behaviour.
+ // The ICMP pointer returned is 22 for Linux but the
+ // error is actually in spot 21.
+ name: "bad option - length bad",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ // Timestamps are in multiples of 4 or 8 but never 7.
+ // The option space should be padded out.
+ options: []byte{
+ 68, 7, 5, 0,
+ // ^ ^ Linux points here which is wrong.
+ // | Not a multiple of 4
+ 1, 2, 3,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "multiple type 0 with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 24, 21, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 24, 25, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // The timestamp area is full so add to the overflow count.
+ name: "multiple type 1 timestamps",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 21, 0x11,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ // Overflow count is the top nibble of the 4th byte.
+ replyOptions: []byte{
+ 68, 20, 21, 0x21,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ },
+ {
+ name: "multiple type 1 timestamps with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 28, 21, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 28, 29, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 192, 168, 1, 58, // New IP Address.
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
+ name: "bad timer element alignment",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 17, 0x01,
+ // ^^ ^^ 20 byte area, next free spot at 17.
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ // End of option list with illegal option after it, which should be ignored.
+ {
+ name: "end of options list",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 10, 3, 99,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0, // 3 bytes unknown option
+ }, // ^ End of options hides following bytes.
+ },
+ {
+ // Timestamp with a size too small.
+ name: "timestamp truncated",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{68, 1, 0, 0},
+ // ^ Smallest possible is 8.
+ shouldFail: true,
+ },
+ {
+ name: "single record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 4, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "multiple record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 20, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "single record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Unlike timestamp, this should just succeed.
+ name: "multiple record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 24, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm linux bug for bug compatibility.
+ // Linux returns slot 22 but the error is in slot 21.
+ name: "multiple record route with not enough room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 8, 8, // 3 byte header
+ // ^ ^ Linux points here. We must too.
+ // | Not enough room. 1 byte free, need 4.
+ 1, 2, 3, 4,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ replyOptions: []byte{},
+ },
+ {
+ name: "duplicate record route",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, 0, // pad
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 7,
+ replyOptions: []byte{},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ Clock: clock,
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e := channel.New(1, defaultMTU, "")
+ e := channel.New(1, ipv4.MaxTotalSize, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -250,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) {
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
+ // Advance the clock by some unimportant amount to make
+ // sure it's all set up.
+ clock.Advance(time.Millisecond * 0x10203040)
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
@@ -312,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) {
reply, ok := e.Read()
if !ok {
if test.shouldFail {
- if test.expectICMP {
- t.Fatal("expected ICMP error response missing")
+ if test.expectErrorICMP {
+ t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
}
return // Expected silent failure.
}
t.Fatal("expected ICMP echo reply missing")
}
+ // We didn't expect a packet. Register our surprise but carry on to
+ // provide more information about what we got.
+ if test.shouldFail && !test.expectErrorICMP {
+ t.Error("unexpected packet response")
+ }
+
// Check the route that brought the packet to us.
if reply.Route.LocalAddress != ipv4Addr.Address {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
@@ -328,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
}
- // Make sure it's all in one buffer.
- vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
- replyIPHeader := header.IPv4(vv.ToView())
+ // Make sure it's all in one buffer for checker.
+ replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
- // At this stage we only know it's an IP header so verify that much.
+ // At this stage we only know it's probably an IP+ICMP header so verify
+ // that much.
checker.IPv4(t, replyIPHeader,
checker.SrcAddr(ipv4Addr.Address),
checker.DstAddr(remoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ ),
)
- // All expected responses are ICMP packets.
- if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
- t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ // Don't proceed any further if the checker found problems.
+ if t.Failed() {
+ t.FailNow()
}
- replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
- // Sanity check the response.
+ // OK it's ICMP. We can safely look at the type now.
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
switch replyICMPHeader.Type() {
- case header.ICMPv4DstUnreachable:
+ case header.ICMPv4ParamProblem:
+ if !test.shouldFail {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
+ }
checker.IPv4(t, replyIPHeader,
checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
checker.IPv4HeaderLength(header.IPv4MinimumSize),
checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Checksum(),
+ checker.ICMPv4Pointer(test.paramProblemPointer),
checker.ICMPv4Payload([]byte(hdr.View())),
),
)
- if !test.shouldFail || !test.expectICMP {
- t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ return
+ case header.ICMPv4DstUnreachable:
+ if !test.shouldFail {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
header.ICMPv4DstUnreachable, replyICMPHeader.Code())
}
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
return
case header.ICMPv4EchoReply:
+ if test.shouldFail {
+ if !test.expectErrorICMP {
+ t.Error("got Echo Reply packet, want no response")
+ } else {
+ t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
+ }
+ }
+ // If the IP options change size then the packet will change size, so
+ // some IP header fields will need to be adjusted for the checks.
+ sizeChange := len(test.replyOptions) - len(test.options)
+
checker.IPv4(t, replyIPHeader,
- checker.IPv4HeaderLength(ipHeaderLength),
- checker.IPv4Options(test.options),
- checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
+ checker.IPv4Options(test.replyOptions),
+ checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
checker.ICMPv4(
+ checker.ICMPv4Checksum(),
checker.ICMPv4Code(header.ICMPv4UnusedCode),
checker.ICMPv4Seq(randomSequence),
checker.ICMPv4Ident(randomIdent),
- checker.ICMPv4Checksum(),
),
)
- if test.shouldFail {
- t.Fatalf("unexpected Echo Reply packet\n")
- }
default:
- t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
- replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
}
})
}
@@ -462,7 +840,7 @@ var fragmentationTests = []struct {
wantFragments []fragmentInfo
}{
{
- description: "No Fragmentation",
+ description: "No fragmentation",
mtu: 1280,
gso: nil,
transportHeaderLength: 0,
@@ -483,6 +861,30 @@ var fragmentationTests = []struct {
},
},
{
+ description: "Fragmented with the minimum mtu",
+ mtu: header.IPv4MinimumMTU,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv4MinimumMTU + 1,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
description: "No fragmentation with big header",
mtu: 2000,
gso: nil,
@@ -647,43 +1049,50 @@ func TestFragmentationWritePackets(t *testing.T) {
}
}
-// TestFragmentationErrors checks that errors are returned from write packet
+// TestFragmentationErrors checks that errors are returned from WritePacket
// correctly.
func TestFragmentationErrors(t *testing.T) {
const ttl = 42
- expectedError := tcpip.ErrAborted
- fragTests := []struct {
+ tests := []struct {
description string
mtu uint32
transportHeaderLength int
payloadSize int
allowPackets int
- fragmentCount int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
}{
{
description: "No frag",
mtu: 2000,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 1,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 3,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on second frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 1,
- fragmentCount: 3,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag MTU smaller than header",
@@ -691,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 1000,
payloadSize: 500,
allowPackets: 0,
- fragmentCount: 4,
+ outgoingErrors: 4,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error when MTU is smaller than IPv4 minimum MTU",
+ mtu: header.IPv4MinimumMTU - 1,
+ transportHeaderLength: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
},
}
- for _, ft := range fragTests {
+ for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
- r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != expectedError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
}
- if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
}
- if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
}
})
}
@@ -744,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) {
autoChecksum bool // if true, the Checksum field will be overwritten.
}
- // These packets have both IHL and TotalLength set to 0.
tests := []struct {
name string
fragments []fragmentData
@@ -984,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
-
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -1027,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) {
}
}
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 99
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ },
+ Clock: clock,
+ })
+ e := channel.New(1, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(header.IPv4ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ipv4.ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ ),
+ )
+ })
+ }
+}
+
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
@@ -1506,13 +2178,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
@@ -1527,17 +2196,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
@@ -1577,7 +2243,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -1783,7 +2449,7 @@ func TestPacketQueing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e := channel.New(1, defaultMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index a30437f02..0ac24a6fb 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -36,6 +36,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index ead6bedcb..8502b848c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
})
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
- stats := r.Stats().ICMP
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
+ stats := e.protocol.stack.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
@@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
h := header.ICMPv6(v)
iph := header.IPv6(pkt.NetworkHeader().View())
+ srcAddr := iph.SourceAddress()
+ dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
//
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
payload.TrimFront(len(h))
- if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
received.Invalid.Increment()
return
}
@@ -170,8 +172,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := header.ICMPv6(hdr).MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -221,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// we know we are also performing DAD on it). In this case we let the
// stack know so it can handle such a scenario and do nothing further with
// the NS.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
// We would get an error if the address no longer exists or the address
// is no longer tentative (DAD resolved between the call to
// hasTentativeAddr and this point). Both of these are valid scenarios:
@@ -248,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -274,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// Otherwise, on link layers that have addresses this option MUST be
// included in multicast solicitations and SHOULD be included in unicast
// solicitations.
- unspecifiedSource := r.RemoteAddress == header.IPv6Any
+ unspecifiedSource := srcAddr == header.IPv6Any
if len(sourceLinkAddr) == 0 {
- if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource {
received.Invalid.Increment()
return
}
@@ -284,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
} else if e.nud != nil {
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -295,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// ...
// - If the IP source address is the unspecified address, the IP
// destination address is a solicited-node multicast address.
- if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) {
received.Invalid.Increment()
return
}
@@ -305,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the source of the solicitation is the unspecified address, the node
// MUST [...] and multicast the advertisement to the all-nodes address.
//
- remoteAddr := r.RemoteAddress
+ remoteAddr := srcAddr
if unspecifiedSource {
remoteAddr = header.IPv6AllNodesMulticastAddress
}
@@ -462,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
- localAddr := r.LocalAddress
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr := dstAddr
+ if header.IsV6MulticastAddress(dstAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -483,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -495,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -516,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- stack := r.Stack()
+ stack := e.protocol.stack
// Is the networking stack operating as a router?
if !stack.Forwarding(ProtocolNumber) {
@@ -547,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
// NOT be included when the source IP address is the unspecified address.
// Otherwise, it SHOULD be included on link layers that have addresses.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
received.Invalid.Increment()
return
}
@@ -555,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
if e.nud != nil {
// A RS with a specified source IP address modifies the NUD state
// machine in the same way a reachability probe would.
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
}
@@ -572,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- routerAddr := iph.SourceAddress()
+ routerAddr := srcAddr
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
@@ -605,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
e.mu.Lock()
@@ -648,52 +657,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
- // route here. Note, we would need the nicID to do this properly so the right
- // NIC (associated to linkEP) is used to send the NDP NS message.
- r := stack.Route{
- LocalAddress: localAddr,
- RemoteAddress: addr,
- LocalLinkAddress: linkEP.LinkAddress(),
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ remoteAddr := targetAddr
+ if len(remoteLinkAddr) == 0 {
+ remoteAddr = header.SolicitedNodeAddr(targetAddr)
+ remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- // If a remote address is not already known, then send a multicast
- // solicitation since multicast addresses have a static mapping to link
- // addresses.
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteAddress = header.SolicitedNodeAddr(addr)
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
+ r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
}
+ defer r.Release()
+ r.ResolveWith(remoteLinkAddr)
optsSerializer := header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
ns := header.NDPNeighborSolicit(packet.NDPPayload())
- ns.SetTargetAddress(addr)
+ ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
+ stat := p.stack.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt); err != nil {
+ stat.Dropped.Increment()
+ return err
+ }
- // TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
+ stat.NeighborSolicit.Increment()
+ return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -747,9 +750,20 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv6(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
@@ -776,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -784,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
stats := p.stack.Stats().ICMP
sent := stats.V6PacketsSent
@@ -839,7 +850,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload := network.ToVectorisedView()
+ payload.AppendView(transport)
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -860,6 +873,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
+ counter = sent.TimeExceeded
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 8dc33c560..76013daa1 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -51,6 +51,7 @@ const (
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
+ lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
@@ -86,7 +87,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
return stack.TransportPacketHandled
}
@@ -108,31 +109,27 @@ type stubNUDHandler struct {
var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) {
s.probeCount++
}
-func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
s.confirmationCount++
}
-func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) {
}
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- stack.NetworkLinkEndpoint
-
- linkAddr tcpip.LinkAddress
-}
+ stack.LinkEndpoint
-func (i *testInterface) LinkAddress() tcpip.LinkAddress {
- return i.linkAddr
+ nicID tcpip.NICID
}
func (*testInterface) ID() tcpip.NICID {
- return 0
+ return nicID
}
func (*testInterface) IsLoopback() bool {
@@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool {
return true
}
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -277,7 +282,8 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -419,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -1235,26 +1242,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
snaddr := header.SolicitedNodeAddr(lladdr0)
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectedLinkAddr tcpip.LinkAddress
- expectedAddr tcpip.Address
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedRemoteAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectedLinkAddr: linkAddr1,
- expectedAddr: lladdr0,
+ name: "Unicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectedLinkAddr: mcaddr,
- expectedAddr: snaddr,
+ name: "Multicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Unicast with no local address available",
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
@@ -1269,26 +1322,43 @@ func TestLinkAddressRequest(t *testing.T) {
}
linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
+ }
+
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
}
pkt, ok := linkEP.Read()
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
}
- if pkt.Route.RemoteAddress != test.expectedAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
}
if pkt.Route.LocalAddress != lladdr1 {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
}
checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedAddr),
+ checker.DstAddr(test.expectedRemoteAddr),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(lladdr0),
@@ -1698,7 +1768,7 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
nudHandler := &stubNUDHandler{}
- ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -1728,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) {
SrcAddr: r.RemoteAddress,
DstAddr: r.LocalAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
if nudHandler.probeCount != test.wantProbeCount {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 9670696c7..0526190cc 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -41,12 +41,12 @@ const (
//
// Linux also uses 60 seconds for reassembly timeout:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
- reassembleTimeout = 60 * time.Second
+ ReassembleTimeout = 60 * time.Second
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
- // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // maxPayloadSize is the maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
@@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return err
}
- prefix := addressEndpoint.AddressWithPrefix().Subnet()
+ prefix := addressEndpoint.Subnet()
switch t := addressEndpoint.ConfigType(); t {
case stack.AddressConfigStatic:
@@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
@@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
-func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU())
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// handleFragments fragments pkt and calls the handler function on each
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
-// original packet. The mtu is the maximum size of the packets. The transport
-// header protocol number is required to avoid parsing the IPv6 extension
-// headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
- if fragMTU < pkt.TransportHeader().View().Size() {
+// original packet. The transport header protocol number is required to avoid
+// parsing the IPv6 extension headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // maximum payload length because they should only be transmitted once.
+ fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7
+ if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit {
+ // We need at least 8 bytes of space left for the fragmentable part because
+ // the fragment payload must obviously be non-zero and must be a multiple
+ // of 8 as per RFC 8200 section 4.5:
+ // Each complete fragment, except possibly the last ("rightmost") one, is
+ // an integer multiple of 8 octets long.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
return 0, 1, tcpip.ErrMessageTooLong
}
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
- networkHeader := header.IPv6(pkt.NetworkHeader().View())
var n int
for {
@@ -448,28 +465,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
- // The inbound path expects an unparsed packet.
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
-
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -499,13 +528,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
- if e.packetMustBeFragmented(pb, gso) {
+
+ networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ if packetMustBeFragmented(pb, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
@@ -546,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -569,7 +607,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
@@ -607,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
// As per RFC 4291 section 2.7:
// Multicast addresses must not be used as source addresses in IPv6
// packets or appear in any Routing header.
- if header.IsV6MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if header.IsV6MulticastAddress(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -641,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -651,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -663,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: previousHeaderStart,
}, pkt)
@@ -675,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -689,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -702,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -727,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
if extHdr.SegmentsLeft() != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -747,6 +790,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
continue
}
+ fragmentFieldOffset := it.ParseOffset()
+
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
@@ -762,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
it, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -790,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch lastHdr.(type) {
case header.IPv6RawPayloadHeader:
default:
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
}
@@ -799,30 +844,70 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length of a fragment, as derived from the fragment packet's
+ // Payload Length field, is not a multiple of 8 octets and the M flag
+ // of that fragment is 1, then that fragment must be discarded and an
+ // ICMP Parameter Problem, Code 0, message should be sent to the source
+ // of the fragment, pointing to the Payload Length field of the
+ // fragment packet.
+ if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6PayloadLenOffset,
+ }, pkt)
return
}
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- // Drop the fragment if the size of the reassembled payload would exceed
- // the maximum payload size.
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length and offset of a fragment are such that the Payload
+ // Length of the packet reassembled from that fragment would exceed
+ // 65,535 octets, then that fragment must be discarded and an ICMP
+ // Parameter Problem, Code 0, message should be sent to the source of
+ // the fragment, pointing to the Fragment Offset field of the fragment
+ // packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: fragmentFieldOffset,
+ }, pkt)
return
}
+ // Set up a callback in case we need to send a Time Exceeded Message as
+ // per RFC 2460 Section 4.5.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+ }
+ }
+
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
- Source: h.SourceAddress(),
- Destination: h.DestinationAddress(),
+ Source: srcAddr,
+ Destination: dstAddr,
ID: extHdr.ID(),
},
start,
@@ -830,10 +915,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.More(),
uint8(rawPayload.Identifier),
rawPayload.Buf,
+ releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
pkt.Data = data
@@ -852,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -866,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -879,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -902,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
- r.Stats().IP.PacketsDelivered.Increment()
+ stats.IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt, hasFragmentHeader)
+ e.handleICMP(pkt, hasFragmentHeader)
} else {
- r.Stats().IP.PacketsDelivered.Increment()
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ stats.IP.PacketsDelivered.Increment()
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC 4443 section 3.1:
@@ -916,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -937,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -947,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
- r.Stats().UnknownProtocolRcvdPackets.Increment()
+ stats.UnknownProtocolRcvdPackets.Increment()
return
}
}
@@ -1427,14 +1513,31 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- mtu -= header.IPv6MinimumSize
- if mtu <= maxPayloadSize {
- return mtu
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU and the length of every IPv6 header.
+// Note that this is different than the Payload Length field of the IPv6 header,
+// which includes the length of the extension headers.
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv6MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // As per RFC 7112 section 5, we should discard packets if their IPv6 header
+ // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not
+ // support PMTU discovery:
+ // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain
+ // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280
+ // bytes ensures that the header chain length does not exceed the IPv6
+ // minimum MTU.
+ if networkHeadersLen > header.IPv6MinimumMTU {
+ return 0, tcpip.ErrMalformedHeader
+ }
+
+ networkMTU := linkMTU - uint32(networkHeadersLen)
+ if networkMTU > maxPayloadSize {
+ networkMTU = maxPayloadSize
}
- return maxPayloadSize
+ return networkMTU, nil
}
// Options holds options to configure a new protocol.
@@ -1488,7 +1591,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
ids: ids,
hashIV: hashIV,
@@ -1509,23 +1612,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
- // supported for outbound packets, their length should not affect the fragment
- // MTU because they should only be transmitted once.
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- mtu -= header.IPv6FragmentHeaderSize
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- if mtu <= maxPayloadSize {
- return mtu
- }
- return maxPayloadSize
-}
-
func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
}
@@ -1560,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
originalIPHeadersLength := len(originalIPHeaders)
fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 297868f24..1bfcdde25 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
@@ -238,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
@@ -271,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -825,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1844,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1912,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- nicID = 1
- fragmentExtHdrLen = 8
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_INVALID_IPV6_FRAGMENTS"
)
- payloadGen := func(payloadLen int) []byte {
- payload := make([]byte, payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = 0x30
- }
- return payload
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
}
tests := []struct {
@@ -1929,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
+ expectICMP bool
+ expectICMPType header.ICMPv6Type
+ expectICMPCode header.ICMPv6Code
+ expectICMPTypeSpecific uint32
}{
{
+ name: "fragment size is not a multiple of 8 and the M flag is true",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0 >> 3,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:9],
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
+ },
+ {
name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
fragments: []fragmentData{
{
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
- []buffer.View{
- // Fragment extension header.
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
- ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
- ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
- 0, 0, 0, 1}),
- // Payload length = 16
- payloadGen(16),
- },
- ),
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
},
},
wantMalformedIPPackets: 1,
wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
},
}
@@ -1964,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) {
NewProtocol,
},
})
- e := channel.New(0, 1500, linkAddr1)
+ e := channel.New(1, 1500, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+ var expectICMPPayload buffer.View
for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
- // Serialize IPv6 fixed header.
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
- })
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
- vv.Append(f.data)
+ vv.AppendView(f.payload)
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- }))
+ })
+
+ if test.expectICMP {
+ expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
}
if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
@@ -1999,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
}
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(test.expectICMPType),
+ checker.ICMPv6Code(test.expectICMPCode),
+ checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
+ checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ ),
+ )
+ })
+ }
+}
+
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ Clock: clock,
+ })
+
+ e := channel.New(1, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
+ checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ ),
+ )
})
}
}
@@ -2035,13 +2360,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
@@ -2056,17 +2378,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
@@ -2230,8 +2549,8 @@ var fragmentationTests = []struct {
wantFragments []fragmentInfo
}{
{
- description: "No Fragmentation",
- mtu: 1280,
+ description: "No fragmentation",
+ mtu: header.IPv6MinimumMTU,
gso: nil,
transHdrLen: 0,
payloadSize: 1000,
@@ -2241,7 +2560,18 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv6MinimumMTU + 1,
gso: nil,
transHdrLen: 0,
payloadSize: 2000,
@@ -2262,7 +2592,7 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented with gso none",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
gso: &stack.GSO{Type: stack.GSONone},
transHdrLen: 0,
payloadSize: 1400,
@@ -2273,7 +2603,7 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented with big header",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
gso: nil,
transHdrLen: 100,
payloadSize: 1200,
@@ -2448,8 +2778,8 @@ func TestFragmentationErrors(t *testing.T) {
wantError: tcpip.ErrAborted,
},
{
- description: "Error on packet with MTU smaller than transport header",
- mtu: 1280,
+ description: "Error when MTU is smaller than transport header",
+ mtu: header.IPv6MinimumMTU,
transHdrLen: 1500,
payloadSize: 500,
allowPackets: 0,
@@ -2457,6 +2787,16 @@ func TestFragmentationErrors(t *testing.T) {
mockError: nil,
wantError: tcpip.ErrMessageTooLong,
},
+ {
+ description: "Error when MTU is smaller than IPv6 minimum MTU",
+ mtu: header.IPv6MinimumMTU - 1,
+ transHdrLen: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
+ },
}
for _, ft := range tests {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index ac20f217e..981d1371a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -341,7 +341,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
if diff := cmp.Diff(existing, n); diff != "" {
t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -368,7 +368,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
}
if ok {
- t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
}
}
})
@@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
}
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
@@ -913,13 +920,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
if diff := cmp.Diff(existing, n); diff != "" {
t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
}
neighborByAddr[n.Addr] = n
}
if neigh, ok := neighborByAddr[lladdr1]; ok {
- t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
}
if test.isValid {
@@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) {
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- ep.HandlePacket(r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
var tllData [header.NDPLinkLayerAddressSize]byte
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 4d3acab96..9478f3fb7 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
+ // Cache the subnet in addrState to avoid calls to addr.Subnet() as that
+ // results in allocations on every call.
+ subnet: addr.Subnet(),
}
a.mu.endpoints[addr.Address] = addrState
addrState.mu.Lock()
@@ -361,6 +364,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
return tcpip.ErrInvalidEndpointState
}
+ a.mu.Lock()
+ defer a.mu.Unlock()
return a.removePermanentEndpointLocked(addrState)
}
@@ -664,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil)
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
-
+ subnet tcpip.Subnet
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
@@ -684,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
+// Subnet implements AddressEndpoint.
+func (a *addressState) Subnet() tcpip.Subnet {
+ return a.subnet
+}
+
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 0cd1da11f..9a17efcba 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.Addr
- replyTID.srcPort = rt.Port
+ replyTID.srcAddr = address
+ replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
@@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ length := uint16(len(tcpHeader) + pkt.Data.Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
- } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index cf042309e..7a501acdc 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 8d6d9a7f1..2d8c883cd 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,30 +22,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// tableID is an index into IPTables.tables.
-type tableID int
+// TableID identifies a specific table.
+type TableID int
+// Each value identifies a specific table.
const (
- natID tableID = iota
- mangleID
- filterID
- numTables
+ NATID TableID = iota
+ MangleID
+ FilterID
+ NumTables
)
-// Table names.
-const (
- NATTable = "nat"
- MangleTable = "mangle"
- FilterTable = "filter"
-)
-
-// nameToID is immutable.
-var nameToID = map[string]tableID{
- NATTable: natID,
- MangleTable: mangleID,
- FilterTable: filterID,
-}
-
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- v4Tables: [numTables]Table{
- natID: Table{
+ v4Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -81,7 +68,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -99,7 +86,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -122,8 +109,8 @@ func DefaultTables() *IPTables {
},
},
},
- v6Tables: [numTables]Table{
- natID: Table{
+ v6Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -146,7 +133,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -164,7 +151,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -187,10 +174,10 @@ func DefaultTables() *IPTables {
},
},
},
- priorities: [NumHooks][]tableID{
- Prerouting: []tableID{mangleID, natID},
- Input: []tableID{natID, filterID},
- Output: []tableID{mangleID, natID, filterID},
+ priorities: [NumHooks][]TableID{
+ Prerouting: []TableID{MangleID, NATID},
+ Input: []TableID{NATID, FilterID},
+ Output: []TableID{MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -229,26 +216,20 @@ func EmptyNATTable() Table {
}
}
-// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- id, ok := nameToID[name]
- if !ok {
- return Table{}, false
- }
+// GetTable returns a table with the given id and IP version. It panics when an
+// invalid id is provided.
+func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
if ipv6 {
- return it.v6Tables[id], true
+ return it.v6Tables[id]
}
- return it.v4Tables[id], true
+ return it.v4Tables[id]
}
-// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- id, ok := nameToID[name]
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
+// ReplaceTable replaces or inserts table by name. It panics when an invalid id
+// is provided.
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
// check the NAT table.
- if tableID == natID && pkt.NatDone {
+ if tableID == NATID && pkt.NatDone {
continue
}
var table Table
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 538c4625d..d63e9757c 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -26,13 +28,6 @@ type AcceptTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (at *AcceptTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: at.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
@@ -44,22 +39,11 @@ type DropTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (dt *DropTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: dt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
-// ErrorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const ErrorTargetName = "ERROR"
-
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
type ErrorTarget struct {
@@ -67,14 +51,6 @@ type ErrorTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (et *ErrorTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: et.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
@@ -90,14 +66,6 @@ type UserChainTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (uc *UserChainTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: uc.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
@@ -110,50 +78,39 @@ type ReturnTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *ReturnTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
-// RedirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const RedirectTargetName = "REDIRECT"
-
-// RedirectTarget redirects the packet by modifying the destination port/IP.
+// RedirectTarget redirects the packet to this machine by modifying the
+// destination port/IP. Outgoing packets are redirected to the loopback device,
+// and incoming packets are redirected to the incoming interface (rather than
+// forwarded).
+//
// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
// them.
type RedirectTarget struct {
- // Addr indicates address used to redirect.
- Addr tcpip.Address
-
- // Port indicates port used to redirect.
+ // Port indicates port used to redirect. It is immutable.
Port uint16
- // NetworkProtocol is the network protocol the target is used with.
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *RedirectTarget) ID() TargetID {
- return TargetID{
- Name: RedirectTargetName,
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ address = tcpip.Address([]byte{127, 0, 0, 1})
} else {
- rt.Addr = header.IPv6Loopback
+ address = header.IPv6Loopback
}
case Prerouting:
- rt.Addr = address
+ // No-op, as address is already set correctly.
default:
panic("redirect target is supported only on output and prerouting hooks")
}
@@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
+ netHeader := pkt.Network()
+ netHeader.SetDestinationAddress(address)
// Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(protocol, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udpHeader.SetChecksum(0)
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- pkt.Network().SetDestinationAddress(rt.Addr)
-
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 7b3f3e88b..4b86c1be9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -37,7 +37,6 @@ import (
// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
-// These values correspond to values in include/uapi/linux/netfilter.h.
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
@@ -86,8 +85,8 @@ type IPTables struct {
mu sync.RWMutex
// v4Tables and v6tables map tableIDs to tables. They hold builtin
// tables only, not user tables. mu must be locked for accessing.
- v4Tables [numTables]Table
- v6Tables [numTables]Table
+ v4Tables [NumTables]Table
+ v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
@@ -96,7 +95,7 @@ type IPTables struct {
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
- priorities [NumHooks][]tableID
+ priorities [NumHooks][]TableID
connections ConnTrack
@@ -104,6 +103,24 @@ type IPTables struct {
reaperDone chan struct{}
}
+// VisitTargets traverses all the targets of all tables and replaces each with
+// transform(target).
+func (it *IPTables) VisitTargets(transform func(Target) Target) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ for tid := range it.v4Tables {
+ for i, rule := range it.v4Tables[tid].Rules {
+ it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+ for tid := range it.v6Tables {
+ for i, rule := range it.v6Tables[tid].Rules {
+ it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+}
+
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
@@ -169,7 +186,6 @@ type IPHeaderFilter struct {
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
- // TODO(gvisor.dev/issue/3549): Check this field during matching.
CheckProtocol bool
// Dst matches the destination IP address.
@@ -309,23 +325,8 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
-// A TargetID uniquely identifies a target.
-type TargetID struct {
- // Name is the target name as stored in the xt_entry_target struct.
- Name string
-
- // NetworkProtocol is the protocol to which the target applies.
- NetworkProtocol tcpip.NetworkProtocolNumber
-
- // Revision is the version of the target.
- Revision uint8
-}
-
// A Target is the interface for taking an action for a packet.
type Target interface {
- // ID uniquely identifies the Target.
- ID() TargetID
-
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 6f73a0ce4..c9b13cd0e 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
return addr, nil, nil
@@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
@@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 33806340e..d2e37f38d 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -49,8 +49,8 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
- time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 4df288798..177bf5516 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -16,7 +16,6 @@ package stack
import (
"fmt"
- "time"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -68,7 +67,7 @@ var _ NUDHandler = (*neighborCache)(nil)
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
@@ -84,7 +83,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
- entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes)
if n.dynamic.count == neighborCacheSize {
e := n.dynamic.lru.Back()
e.mu.Lock()
@@ -111,28 +110,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// provided, it will be notified when address resolution is complete (success
// or not).
//
+// If specified, the local address must be an address local to the interface the
+// neighbor cache belongs to. The local address is the source address of a
+// packet prompting NUD/link address resolution.
+//
// If address resolution is required, ErrNoLinkAddress and a notification
// channel is returned for the top level caller to block. Channel is closed
// once address resolution is complete (success or not).
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
+ Addr: remoteAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: 0,
}
return e, nil, nil
}
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
case Stale:
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
fallthrough
case Reachable, Static, Delay, Probe:
// As per RFC 4861 section 7.3.3:
@@ -152,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
@@ -207,7 +209,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
} else {
// Static entry found with the same address but different link address.
entry.neigh.LinkAddr = linkAddr
- entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
}
@@ -220,8 +222,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
- n.cache[addr] = entry
+ n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
// removeEntryLocked removes the specified entry from the neighbor cache.
@@ -292,8 +293,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
// by the caller.
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index fcd54ed83..ed33418f3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -61,23 +61,20 @@ const (
)
// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
-// entries. The UpdatedAt field is ignored due to a lack of a deterministic
-// method to predict the time that an event will be dispatched.
+// entries. The UpdatedAtNanos field is ignored due to a lack of a
+// deterministic method to predict the time that an event will be dispatched.
func entryDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
// sort slices of entries for cases where ordering must be ignored.
func entryDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b NeighborEntry) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }))
}
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
@@ -128,9 +125,8 @@ func newTestEntryStore() *testEntryStore {
linkAddr := toLinkAddress(i)
store.entriesMap[addr] = NeighborEntry{
- Addr: addr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: linkAddr,
+ Addr: addr,
+ LinkAddr: linkAddr,
}
}
return store
@@ -195,10 +191,10 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(addr)
+ r.fakeRequest(targetAddr)
})
// Execute post address resolution action, if available.
@@ -294,9 +290,8 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -305,15 +300,19 @@ func TestNeighborCacheEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -324,8 +323,8 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,9 +353,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -365,15 +364,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -391,9 +394,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -404,8 +409,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -452,8 +457,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -470,23 +475,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
})
}
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
}, testEntryEventInfo{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
})
c.nudDisp.mu.Lock()
@@ -508,10 +519,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -564,24 +574,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -600,9 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -640,9 +655,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -682,9 +699,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -703,9 +722,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -740,9 +761,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -760,9 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -800,24 +825,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -836,16 +864,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -861,10 +893,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
},
},
}
@@ -896,12 +927,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
clock.Advance(typicalLatency)
@@ -913,7 +944,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
id, ok := s.Fetch(false /* block */)
if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
}
if id != wakerID {
t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
@@ -923,15 +954,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -964,12 +999,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
// Remove the waker before the neighbor cache has the opportunity to send a
@@ -991,15 +1026,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1028,10 +1067,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
@@ -1041,9 +1079,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1058,10 +1098,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
},
},
}
@@ -1089,9 +1128,8 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1099,15 +1137,19 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1126,9 +1168,11 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1149,16 +1193,20 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1185,24 +1233,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1220,9 +1271,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1274,29 +1327,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1312,9 +1369,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
for i := neighborCacheSize; i < store.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
- _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1322,15 +1378,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1342,22 +1398,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1374,10 +1436,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
- Addr: frequentlyUsedEntry.Addr,
- LocalAddr: frequentlyUsedEntry.LocalAddr,
- LinkAddr: frequentlyUsedEntry.LinkAddr,
- State: Reachable,
+ Addr: frequentlyUsedEntry.Addr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
},
}
@@ -1387,10 +1448,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1430,9 +1490,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
@@ -1456,10 +1515,9 @@ func TestNeighborCacheConcurrent(t *testing.T) {
t.Errorf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1488,37 +1546,36 @@ func TestNeighborCacheReplace(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
}
if t.Failed() {
t.FailNow()
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1542,37 +1599,35 @@ func TestNeighborCacheReplace(t *testing.T) {
//
// Verify the entry's new link address and the new state.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
// Verify that the neighbor is now reachable.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1601,35 +1656,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
- got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
// Verify that address resolution for an unknown address returns ErrNoLinkAddress
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1659,13 +1713,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
}
@@ -1683,18 +1737,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
delay: typicalLatency,
}
- got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
- Addr: testEntryBroadcastAddr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: testEntryBroadcastLinkAddr,
- State: Static,
+ Addr: testEntryBroadcastAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1719,9 +1772,9 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh != nil {
<-doneCh
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index be61a21af..493e48031 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -24,13 +24,18 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+const (
+ // immediateDuration is a duration of zero for scheduling work that needs to
+ // be done immediately but asynchronously to avoid deadlock.
+ immediateDuration time.Duration = 0
+)
+
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
- Addr tcpip.Address
- LocalAddr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAtNanos int64
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
@@ -106,35 +111,35 @@ type neighborEntry struct {
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
-func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
return &neighborEntry{
nic: nic,
linkRes: linkRes,
nudState: nudState,
neigh: NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- State: Unknown,
+ Addr: remoteAddr,
+ State: Unknown,
},
}
}
-// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
-// state. The entry can only transition out of Static by directly calling
-// `setStateLocked`.
+// newStaticNeighborEntry creates a neighbor cache entry starting at the
+// Static state. The entry can only transition out of Static by directly
+// calling `setStateLocked`.
func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ entry := NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: nic.stack.clock.NowNanoseconds(),
+ }
if nic.stack.nudDisp != nil {
- nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, entry)
}
return &neighborEntry{
nic: nic,
nudState: state,
- neigh: NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
- },
+ neigh: entry,
}
}
@@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
-func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
}
}
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
-func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
}
}
@@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
// has been removed.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
@@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.neigh.State
e.neigh.State = next
- e.neigh.UpdatedAt = time.Now()
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
config := e.nudState.Config()
switch next {
case Incomplete:
- var retryCounter uint32
- var sendMulticastProbe func()
-
- sendMulticastProbe = func() {
- if retryCounter == config.MaxMulticastProbes {
- // "If no Neighbor Advertisement is received after
- // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
- // The sender MUST return ICMP destination unreachable indications with
- // code 3 (Address Unreachable) for each packet queued awaiting address
- // resolution." - RFC 4861 section 7.2.2
- //
- // There is no need to send an ICMP destination unreachable indication
- // since the failure to resolve the address is expected to only occur
- // on this node. Thus, redirecting traffic is currently not supported.
- //
- // "If the error occurs on a node other than the node originating the
- // packet, an ICMP error message is generated. If the error occurs on
- // the originating node, an implementation is not required to actually
- // create and send an ICMP error packet to the source, as long as the
- // upper-layer sender is notified through an appropriate mechanism
- // (e.g. return value from a procedure call). Note, however, that an
- // implementation may find it convenient in some cases to return errors
- // to the sender by taking the offending packet, generating an ICMP
- // error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
- // There is no need to log the error here; the NUD implementation may
- // assume a working link. A valid link should be the responsibility of
- // the NIC/stack.LinkEndpoint.
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
-
- sendMulticastProbe()
+ panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
case Reachable:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(e.nudState.ReachableTime())
case Delay:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Probe)
e.setStateLocked(Probe)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(config.DelayFirstProbeTime)
@@ -277,24 +238,23 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
}
retryCounter++
- if retryCounter == config.MaxUnicastProbes {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(config.RetransmitTimer)
}
- sendUnicastProbe()
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(immediateDuration)
case Failed:
e.notifyWakersLocked()
@@ -315,15 +275,77 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
-func (e *neighborEntry) handlePacketQueuedLocked() {
+func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
case Unknown:
- e.dispatchAddEventLocked(Incomplete)
- e.setStateLocked(Incomplete)
+ e.neigh.State = Incomplete
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+
+ e.dispatchAddEventLocked()
+
+ config := e.nudState.Config()
+
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ // As per RFC 4861 section 7.2.2:
+ //
+ // If the source address of the packet prompting the solicitation is the
+ // same as one of the addresses assigned to the outgoing interface, that
+ // address SHOULD be placed in the IP Source Address of the outgoing
+ // solicitation.
+ //
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(immediateDuration)
case Stale:
- e.dispatchChangeEventLocked(Delay)
e.setStateLocked(Delay)
+ e.dispatchChangeEventLocked()
case Incomplete, Reachable, Delay, Probe, Static, Failed:
// Do nothing
@@ -345,21 +367,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
switch e.neigh.State {
case Unknown, Incomplete, Failed:
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchAddEventLocked(Stale)
e.setStateLocked(Stale)
e.notifyWakersLocked()
+ e.dispatchAddEventLocked()
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Stale:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Static:
@@ -393,12 +415,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.neigh.LinkAddr = linkAddr
if flags.Solicited {
- e.dispatchChangeEventLocked(Reachable)
e.setStateLocked(Reachable)
} else {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
}
+ e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
e.notifyWakersLocked()
@@ -411,8 +432,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
break
}
@@ -421,23 +442,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if !flags.Solicited {
if e.neigh.State != Stale {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
} else {
// Notify the LinkAddr change, even though NUD state hasn't changed.
- e.dispatchChangeEventLocked(e.neigh.State)
+ e.dispatchChangeEventLocked()
}
break
}
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- }
+ wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyWakersLocked()
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
}
if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
@@ -475,11 +497,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- // Set state to Reachable again to refresh timers.
- }
+ wasReachable := e.neigh.State == Reachable
+ // Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
case Unknown, Incomplete, Failed, Static:
// Do nothing
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 3ee2a3b31..c2b763325 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -47,24 +47,27 @@ const (
entryTestNetDefaultMTU = 65536
)
+// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
+ clock.Advance(immediateDuration)
+}
+
// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
-// The UpdatedAt field is ignored due to a lack of a deterministic method to
-// predict the time that an event will be dispatched.
+// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
+// to predict the time that an event will be dispatched.
func eventDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
// sort slices of events for cases where ordering must be ignored.
func eventDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
+ }))
}
// The following unit tests exercise every state transition and verify its
@@ -125,14 +128,11 @@ func (t testEntryEventType) String() string {
type testEntryEventInfo struct {
EventType testEntryEventType
NICID tcpip.NICID
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Entry NeighborEntry
}
func (e testEntryEventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry)
}
// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
@@ -150,36 +150,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
d.events = append(d.events, e)
}
-func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestAdded,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestChanged,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestRemoved,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
@@ -202,9 +193,9 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
p := entryTestProbeInfo{
- RemoteAddress: addr,
+ RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
LocalAddress: localAddr,
}
@@ -245,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes)
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
@@ -323,15 +314,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
func TestEntryUnknownToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -350,9 +342,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
{
@@ -367,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
func TestEntryUnknownToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
@@ -377,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Unlock()
// No probes should have been sent.
+ runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
linkRes.mu.Unlock()
@@ -388,9 +383,11 @@ func TestEntryUnknownToStale(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -406,11 +403,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
- updatedAt := e.neigh.UpdatedAt
+ updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
clock.Advance(c.RetransmitTimer)
@@ -437,7 +434,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want {
t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -468,16 +465,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -487,7 +488,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant {
t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
}
e.mu.Unlock()
@@ -495,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
func TestEntryIncompleteToReachable(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -526,20 +520,35 @@ func TestEntryIncompleteToReachable(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -555,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
// to Reachable.
func TestEntryAddsAndClearsWakers(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
w := sleep.Waker{}
s := sleep.Sleeper{}
@@ -563,7 +572,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
defer s.Done()
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -587,34 +614,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -626,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -659,20 +666,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
}
linkRes.mu.Unlock()
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -684,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
func TestEntryIncompleteToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -715,20 +733,35 @@ func TestEntryIncompleteToStale(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -744,7 +777,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -783,16 +816,20 @@ func TestEntryIncompleteToFailed(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -817,12 +854,30 @@ func (*testLocker) Unlock() {}
func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -848,34 +903,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -893,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -928,20 +959,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -961,17 +1014,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -986,29 +1032,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1026,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1058,27 +1110,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1086,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1132,27 +1184,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1160,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1206,27 +1262,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1234,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1279,20 +1340,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1304,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1343,27 +1408,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1375,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
e.mu.Lock()
- e.handlePacketQueuedLocked()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1400,41 +1511,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1446,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1485,27 +1570,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1517,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1552,27 +1651,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1584,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
func TestEntryStaleToDelay(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1616,27 +1728,48 @@ func TestEntryStaleToDelay(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1656,22 +1789,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleUpperLevelConfirmationLocked()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1686,43 +1807,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1743,29 +1889,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1780,43 +1907,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1837,13 +1996,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if e.neigh.State != Delay {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
@@ -1860,57 +2037,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1922,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1962,27 +2115,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1994,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2027,34 +2197,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2066,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2103,34 +2281,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2145,69 +2351,91 @@ func TestEntryDelayToProbe(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2228,36 +2456,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2274,37 +2516,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2312,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
@@ -2325,36 +2571,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2375,37 +2635,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2413,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
@@ -2426,36 +2690,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2479,30 +2758,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2529,17 +2816,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
wantProbes := []entryTestProbeInfo{
- // Probe caused by the Delay-to-Probe transition
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2567,42 +2851,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2622,36 +2915,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2672,49 +2979,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2734,36 +3052,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2781,49 +3113,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2843,36 +3186,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2890,49 +3247,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2946,87 +3314,116 @@ func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
c.MaxUnicastProbes = 3
+ c.DelayFirstProbeTime = c.RetransmitTimer
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.Advance(waitFor)
+ // Observe each probe sent while in the Probe state.
+ for i := uint32(0); i < c.MaxUnicastProbes; i++ {
+ clock.Advance(c.RetransmitTimer)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ }
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.mu.Unlock()
}
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+
+ // Wait for the last probe to expire, causing a transition to Failed.
+ clock.Advance(c.RetransmitTimer)
+ e.mu.Lock()
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -3034,12 +3431,6 @@ func TestEntryProbeToFailed(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryFailedGetsDeleted(t *testing.T) {
@@ -3054,84 +3445,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
}
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
clock.Advance(waitFor)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The next three probe are sent in Probe.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dcd4319bf..60c81a3aa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
return n.writePacket(r, gso, protocol, pkt)
}
+// WritePacketToRemote implements NetworkInterface.
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ r := Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return n.writePacket(&r, gso, protocol, pkt)
+}
+
func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
@@ -339,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
+func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ return true
+ }
+
+ return false
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
@@ -546,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
}
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
defer r.Release()
- r.RemoteLinkAddress = remotelinkAddr
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -585,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if local == "" {
local = n.LinkEndpoint.LinkAddress()
}
+ pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
@@ -660,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.nic
+ n := r.localAddressNIC
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
- r.RemoteLinkAddress = remote
+ pkt.NICID = n.ID()
r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -678,7 +697,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
+ // the TTL field for ipv4/ipv6.
// pkt may have set its header and may not have enough headroom for
// link-layer header for the other link to prepend. Here we create a new
@@ -725,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -737,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, pkt)
+ n.stack.demux.deliverRawPacket(protocol, pkt)
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
@@ -766,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
return TransportPacketHandled
}
- id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers()))
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+ id := TransportEndpointID{
+ LocalPort: dstPort,
+ LocalAddress: dst,
+ RemotePort: srcPort,
+ RemoteAddress: src,
+ }
+ if n.stack.demux.deliverPacket(protocol, pkt, id) {
return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, pkt) {
+ if state.defaultHandler(id, pkt) {
return TransportPacketHandled
}
}
@@ -781,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet so
// give the protocol specific error handler a chance to handle it.
// If it doesn't handle it then we should do so.
- switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
return TransportPacketHandled
@@ -885,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 97a96af62..5b5c58afb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
-}
+func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
// Close implements NetworkEndpoint.Close.
func (e *testIPv6Endpoint) Close() {
@@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index e1ec15487..ab629b3a4 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -129,7 +129,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborAdded(tcpip.NICID, NeighborEntry)
// OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
// neighbor table changes state and/or link address.
@@ -138,7 +138,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborChanged(tcpip.NICID, NeighborEntry)
// OnNeighborRemoved will be called when an entry is removed from a NIC's
// (with ID nicID) neighbor table.
@@ -147,7 +147,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborRemoved(tcpip.NICID, NeighborEntry)
}
// ReachabilityConfirmationFlags describes the flags used within a reachability
@@ -177,7 +177,7 @@ type NUDHandler interface {
// Neighbor Solicitation for ARP or NDP, respectively). Validation of the
// probe needs to be performed before calling this function since the
// Neighbor Cache doesn't have access to view the NIC's assigned addresses.
- HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+ HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7f54a6de8..664cc6fa0 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -112,6 +112,16 @@ type PacketBuffer struct {
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
+
+ // NICID is the ID of the interface the network packet was received at.
+ NICID tcpip.NICID
+
+ // RXTransportChecksumValidated indicates that transport checksum verification
+ // may be safely skipped.
+ RXTransportChecksumValidated bool
+
+ // NetworkPacketInfo holds an incoming packet's network-layer information.
+ NetworkPacketInfo NetworkPacketInfo
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// Clone should be called in such cases so that no modifications is done to
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
- TransportProtocolNumber: pk.TransportProtocolNumber,
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
+ PktType: pk.PktType,
+ NICID: pk.NICID,
+ RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
+ NetworkPacketInfo: pk.NetworkPacketInfo,
}
- return newPk
+}
+
+// SourceLinkAddress returns the source link address of the packet.
+func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
+ link := pk.LinkHeader().View()
+
+ if link.IsEmpty() {
+ return ""
+ }
+
+ return header.Ethernet(link).SourceAddress()
}
// Network returns the network header as a header.Network.
@@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
+// CloneToInbound makes a shallow copy of the packet buffer to be used as an
+// inbound packet.
+//
+// See PacketBuffer.Data for details about how a packet buffer holds an inbound
+// packet.
+func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
+ return NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
+ })
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index f838eda8d..5d364a2b0 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
} else if _, err := p.route.Resolve(nil); err != nil {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index defb9129b..b8f333057 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -63,17 +63,28 @@ const (
ControlUnknown
)
+// NetworkPacketInfo holds information about a network layer packet.
+type NetworkPacketInfo struct {
+ // RemoteAddressBroadcast is true if the packet's remote address is a
+ // broadcast address.
+ RemoteAddressBroadcast bool
+
+ // LocalAddressBroadcast is true if the packet's local address is a broadcast
+ // address.
+ LocalAddressBroadcast bool
+}
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// UniqueID returns an unique ID for this transport endpoint.
UniqueID() uint64
- // HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint. It sets pkt.TransportHeader.
+ // HandlePacket is called by the stack when new packets arrive to this
+ // transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
@@ -105,8 +116,8 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(*PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -172,9 +183,9 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // HandleUnknownDestinationPacket takes ownership of the packet if it handles
// the issue.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
+ HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -227,8 +238,8 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
+ // DeliverTransportPacket takes ownership of the packet.
+ DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface {
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
+ // Subnet returns the subnet of the endpoint's address.
+ Subnet() tcpip.Subnet
+
// IsAssigned returns whether or not the endpoint is considered bound
// to its NetworkEndpoint.
IsAssigned(allowExpired bool) bool
@@ -490,6 +504,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // WritePacketToRemote writes the packet to the given remote link address.
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -544,7 +561,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -764,13 +781,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
- // the request on the local network if remoteLinkAddr is the zero value. The
- // request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the link address of the target
+ // address. The request is broadcasted on the local network if a remote link
+ // address is not provided.
//
- // A valid response will cause the discovery protocol's network
- // endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
+ // The request is sent from the passed network interface. If the interface
+ // local address is unspecified, any interface local address may be used.
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..15ff437c7 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,11 +47,16 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // nic is the NIC the route goes through.
- nic *NIC
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
- // addressEndpoint is the local address this route is associated with.
- addressEndpoint AssignableAddressEndpoint
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
// linkCache is set if link address resolution is enabled for this protocol on
// the route's NIC.
@@ -60,51 +67,144 @@ type Route struct {
linkRes LinkAddressResolver
}
+// constructAndValidateRoute validates and initializes a route. It takes
+// ownership of the provided local address.
+//
+// Returns an empty route if validation fails.
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+ addrWithPrefix := addressEndpoint.AddressWithPrefix()
+
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ addressEndpoint.DecRef()
+ return Route{}
+ }
+
+ // If no remote address is provided, use the local address.
+ if len(remoteAddr) == 0 {
+ remoteAddr = addrWithPrefix.Address
+ }
+
+ r := makeRoute(
+ netProto,
+ addrWithPrefix.Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ addressEndpoint,
+ handleLocal,
+ multicastLoop,
+ )
+
+ // If the route requires us to send a packet through some gateway, do not
+ // broadcast it.
+ if len(gateway) > 0 {
+ r.NextHop = gateway
+ } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ }
+
+ return r
+}
+
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+ if localAddressNIC.stack != outgoingNIC.stack {
+ panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
+ }
+
loop := PacketOut
- if handleLocal && localAddr != "" && remoteAddr == localAddr {
- loop = PacketLoop
- } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
- loop |= PacketLoop
- } else if remoteAddr == header.IPv4Broadcast {
- loop |= PacketLoop
+
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if !outgoingNIC.IsLoopback() {
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ loop |= PacketLoop
+ }
}
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- addressEndpoint: addressEndpoint,
- nic: nic,
- Loop: loop,
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ localAddressEndpoint: localAddressEndpoint,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = r.outgoingNIC.stack
}
}
return r
}
+// makeLocalRoute initializes a new local route. It takes ownership of the
+// provided AssignableAddressEndpoint.
+//
+// A local route is a route to a destination that is local to the stack.
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+ loop := PacketLoop
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if outgoingNIC.IsLoopback() {
+ loop = PacketOut
+ }
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+// PopulatePacketInfo populates a packet buffer's packet information fields.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
+ if r.local() {
+ pkt.RXTransportChecksumValidated = true
+ }
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+}
+
+// networkPacketInfo returns the network packet information of the route.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) networkPacketInfo() NetworkPacketInfo {
+ return NetworkPacketInfo{
+ RemoteAddressBroadcast: r.IsOutboundBroadcast(),
+ LocalAddressBroadcast: r.isInboundBroadcast(),
+ }
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.nic.ID()
+ return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.nic.stack.Stats()
+ return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
-// Capabilities returns the link-layer capabilities of the route.
-func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+// RequiresTXTransportChecksum returns false if the route does not require
+// transport checksums to be populated.
+func (r *Route) RequiresTXTransportChecksum() bool {
+ if r.local() {
+ return false
+ }
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
+}
+
+// HasSoftwareGSOCapability returns true if the route supports software GSO.
+func (r *Route) HasSoftwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+}
+
+// HasHardwareGSOCapability returns true if the route supports hardware GSO.
+func (r *Route) HasHardwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+}
+
+// HasSaveRestoreCapability returns true if the route supports save/restore.
+func (r *Route) HasSaveRestoreCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
+}
+
+// HasDisconncetOkCapability returns true if the route supports disconnecting.
+func (r *Route) HasDisconncetOkCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ // If specified, the local address used for link address resolution must be an
+ // address on the outgoing interface.
+ var linkAddressResolutionRequestLocalAddr tcpip.Address
+ if r.localAddressNIC == r.outgoingNIC {
+ linkAddressResolutionRequestLocalAddr = r.LocalAddress
+ }
+
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
neigh.removeWaker(nextAddr, waker)
return
}
- r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
+}
+
+// local returns true if the route is a local route.
+func (r *Route) local() bool {
+ return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before the route can be written to.
//
-// The NIC r uses must not be locked.
+// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.nic.neigh != nil {
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ return false
}
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
+
+ return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil
+}
+
+func (r *Route) isValidForOutgoing() bool {
+ if !r.outgoingNIC.Enabled() {
+ return false
+ }
+
+ if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ return false
+ }
+
+ // If the source NIC and outgoing NIC are different, make sure the stack has
+ // forwarding enabled, or the packet will be handled locally.
+ if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) {
+ return false
+ }
+
+ return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.nic.getNetworkEndpoint(r.NetProto).MTU()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.addressEndpoint != nil {
- r.addressEndpoint.DecRef()
- r.addressEndpoint = nil
+ if r.localAddressEndpoint != nil {
+ r.localAddressEndpoint.DecRef()
+ r.localAddressEndpoint = nil
}
}
// Clone clones the route.
func (r *Route) Clone() Route {
- if r.addressEndpoint != nil {
- _ = r.addressEndpoint.IncRef()
+ if r.localAddressEndpoint != nil {
+ if !r.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
+ }
}
return *r
}
@@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.nic.stack
+ return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := r.localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool {
return r.isV4Broadcast(r.RemoteAddress)
}
-// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// isInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
-func (r *Route) IsInboundBroadcast() bool {
+func (r *Route) isInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.LocalAddress)
}
@@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool {
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- addressEndpoint: r.addressEndpoint,
- nic: r.nic,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ localAddressEndpoint: r.localAddressEndpoint,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3a07577c8..a23fb97ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -518,6 +519,10 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // IPTables are the initial iptables rules. If nil, iptables will allow
+ // all traffic.
+ IPTables *IPTables
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -620,6 +625,10 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
+ if opts.IPTables == nil {
+ opts.IPTables = DefaultTables()
+ }
+
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
@@ -633,7 +642,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
+ tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
@@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -830,6 +839,20 @@ func (s *Stack) AddRoute(route tcpip.Route) {
s.routeTable = append(s.routeTable, route)
}
+// RemoveRoutes removes matching routes from the route table.
+func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var filteredRoutes []tcpip.Route
+ for _, route := range s.routeTable {
+ if !match(route) {
+ filteredRoutes = append(filteredRoutes, route)
+ }
+ }
+ s.routeTable = filteredRoutes
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -1180,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return Route{}, false
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return Route{}, false
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return Route{}, false
+ }
+
+ return r, true
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
+ return r, true
+ }
+ }
+
+ return Route{}, false
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return Route{}, false
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified, the returned route will leave through the same
+// NIC as the NIC that has the local address assigned when forwarding is
+// disabled. If forwarding is enabled and the NIC is unspecified, the route may
+// leave through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack wil use a remote address equal to the
+// local address.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ addressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ nic, /* outboundNIC */
+ nic, /* localAddressNIC*/
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
+ if r == (Route{}) {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+ // is assigned to the outgoing interface. There is no requirement to do this
+ // from any RFC but simply a choice made to better follow a strong host
+ // model which the netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
+
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ return r, nil
}
+ }
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+ }
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
return r, nil
}
}
}
}
- if !needRoute {
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return Route{}, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1323,7 +1517,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1443,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1896,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e75f58c64..dedfdd435 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
@@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
+ netHdr := pkt.NetworkHeader().View()
+ f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ pkt.NetworkProtocolNumber = fakeNetNumber
hdr[dstAddrOffset] = r.RemoteAddress[0]
hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- f.HandlePacket(r, pkt)
+ pkt := pkt.Clone()
+ r.PopulatePacketInfo(pkt)
+ f.HandlePacket(pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
if !ok {
return 0, false, false
}
+ pkt.NetworkProtocolNumber = fakeNetNumber
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) {
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
+// TestExternalSendWithHandleLocal tests that the stack creates a non-local
+// route when spoofing or promiscuous mode are enabled.
+//
+// This test makes sure that packets are transmitted from the stack.
+func TestExternalSendWithHandleLocal(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+
+ localAddr = tcpip.Address("\x01")
+ dstAddr = tcpip.Address("\x03")
+ )
+
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ configureStack func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Default",
+ configureStack: func(*testing.T, *stack.Stack) {},
+ },
+ {
+ name: "Spoofing",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ {
+ name: "Promiscuous",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, handleLocal := range []bool{true, false} {
+ t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ HandleLocal: handleLocal,
+ })
+
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
+
+ test.configureStack(t, s)
+
+ r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
+ }
+ defer r.Release()
+
+ if r.LocalAddress != localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+
+ if n := ep.Drain(); n != 0 {
+ t.Fatalf("got ep.Drain() = %d, want = 0", n)
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: fakeTransNumber,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.NewView(10).ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, _, _): %s", err)
+ }
+ if n := ep.Drain(); n != 1 {
+ t.Fatalf("got ep.Drain() = %d, want = 1", n)
+ }
+ })
+ }
+ })
+ }
+}
+
func TestSpoofingWithAddress(t *testing.T) {
localAddr := tcpip.Address("\x01")
nonExistentLocalAddr := tcpip.Address("\x02")
@@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
RemoteAddress: ipv4SubnetBcast,
RemoteLinkAddress: header.EthernetBroadcastAddress,
NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
+ Loop: stack.PacketOut | stack.PacketLoop,
},
},
// Broadcast to a locally attached /31 subnet does not populate the
@@ -3672,3 +3778,453 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
}
}
+
+// TestAddRoute tests Stack.AddRoute
+func TestAddRoute(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ subnet1, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expected := []tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ }
+
+ // Initialize the route table with one route.
+ s.SetRouteTable([]tcpip.Route{expected[0]})
+
+ // Add another route.
+ s.AddRoute(expected[1])
+
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+// TestRemoveRoutes tests Stack.RemoveRoutes
+func TestRemoveRoutes(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ addressToRemove := tcpip.Address("\x01")
+ subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet3, err := tcpip.NewSubnet("\x02", "\x02")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Initialize the route table with three routes.
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ {Destination: subnet3, Gateway: "\x00", NIC: 1},
+ })
+
+ // Remove routes with the specific address.
+ s.RemoveRoutes(func(r tcpip.Route) bool {
+ return r.Destination.ID() == addressToRemove
+ })
+
+ expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}}
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+func TestFindRouteWithForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Addr = tcpip.Address("\x01")
+ nic2Addr = tcpip.Address("\x02")
+ remoteAddr = tcpip.Address("\x03")
+ )
+
+ type netCfg struct {
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1Addr tcpip.Address
+ nic2Addr tcpip.Address
+ remoteAddr tcpip.Address
+ }
+
+ fakeNetCfg := netCfg{
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1Addr: nic1Addr,
+ nic2Addr: nic2Addr,
+ remoteAddr: remoteAddr,
+ }
+
+ globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
+ globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
+
+ ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: llAddr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: globalIPv6Addr1,
+ }
+ ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: llAddr1,
+ remoteAddr: llAddr2,
+ }
+ ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }
+
+ tests := []struct {
+ name string
+
+ netCfg netCfg
+ forwardingEnabled bool
+
+ addrNIC tcpip.NICID
+ localAddr tcpip.Address
+
+ findRouteErr *tcpip.Error
+ dependentOnForwarding bool
+ }{
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ }
+
+ if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ }
+
+ if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
+
+ r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if err != test.findRouteErr {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ }
+ defer r.Release()
+
+ if test.findRouteErr != nil {
+ return
+ }
+
+ if r.LocalAddress != test.localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr)
+ }
+ if r.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Sending a packet should always go through NIC2 since we only install a
+ // route to test.netCfg.remoteAddr through NIC2.
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := send(r, data); err != nil {
+ t.Fatalf("send(_, _): %s", err)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not sent through ep2")
+ }
+ if pkt.Route.LocalAddress != test.localAddr {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ }
+ if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if !test.forwardingEnabled || !test.dependentOnForwarding {
+ return
+ }
+
+ // Disabling forwarding when the route is dependent on forwarding being
+ // enabled should make the route invalid.
+ if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ }
+ if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ if n := ep2.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep2", n)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.e
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 62ab6d92f..c457b67a2 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -28,7 +28,7 @@ import (
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
+ fakeTransHeaderLen int = 3
)
// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
@@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
+ if f.acceptQueue == nil {
+ return
}
+
+ netHdr := pkt.NetworkHeader().View()
+ route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ route.ResolveWith(pkt.SourceLinkAddress())
+
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: route.RemoteAddress,
+ route: route,
+ })
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
@@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index d77848d61..3ab2b7654 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool {
return s.Prefix() <= 30 && s.Broadcast() == address
}
-// Equal returns true if s equals o.
-//
-// Needed to use cmp.Equal on Subnet as its fields are unexported.
+// Equal returns true if this Subnet is equal to the given Subnet.
func (s Subnet) Equal(o Subnet) bool {
+ // If this changes, update Route.Equal accordingly.
return s == o
}
@@ -763,6 +762,10 @@ const (
// endpoint that all packets being written have an IP header and the
// endpoint should not attach an IP header.
IPHdrIncludedOption
+
+ // AcceptConnOption is used by GetSockOptBool to indicate if the
+ // socket is a listening socket.
+ AcceptConnOption
)
// SockOptInt represents socket options which values have the int type.
@@ -1256,6 +1259,12 @@ func (r Route) String() string {
return out.String()
}
+// Equal returns true if the given Route is equal to this Route.
+func (r Route) Equal(to Route) bool {
+ // NOTE: This relies on the fact that r.Destination == to.Destination
+ return r == to
+}
+
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -1496,6 +1505,15 @@ type IPStats struct {
// IPTablesOutputDropped is the total number of IP packets dropped in
// the Output chain.
IPTablesOutputDropped *StatCounter
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived *StatCounter
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived *StatCounter
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived *StatCounter
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 34aab32d0..9b0f3b675 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -10,6 +10,7 @@ go_test(
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
+ "route_test.go",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 0dcef7b04..bf7594268 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -33,11 +33,6 @@ import (
func TestForwarding(t *testing.T) {
const (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
- routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
host1NICID = 1
routerNICID1 = 2
routerNICID2 = 3
@@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) {
}
},
},
+ {
+ name: "IPv4 host2 server with routerNIC1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 routerNIC2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
}
for _, test := range tests {
@@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) {
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr)
- routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr)
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) {
if err == tcpip.ErrNoLinkAddress {
// Wait for link resolution to complete.
<-ch
-
n, _, err = ep.Write(dataPayload, wOpts)
- } else if err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
}
-
if err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
}
@@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
-
var addr tcpip.FullAddress
v, _, err := ep.Read(&addr)
if err != nil {
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 6ddcda70c..fe7c1bb3d 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -32,32 +32,36 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+const (
+ linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+)
- host1IPv4Addr = tcpip.ProtocolAddress{
+var (
+ ipv4Addr1 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
- host2IPv4Addr = tcpip.ProtocolAddress{
+ ipv4Addr2 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 8,
},
}
- host1IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr1 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::1").To16()),
PrefixLen: 64,
},
}
- host2IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr2 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::2").To16()),
@@ -89,7 +93,7 @@ func TestPing(t *testing.T) {
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
- remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
@@ -104,7 +108,7 @@ func TestPing(t *testing.T) {
name: "IPv6 Ping",
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
- remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
@@ -127,7 +131,7 @@ func TestPing(t *testing.T) {
host1Stack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr)
+ host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -143,36 +147,36 @@ func TestPing(t *testing.T) {
t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
tcpip.Route{
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
})
host2Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
tcpip.Route{
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
})
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index e8caf09ba..421da1add 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
},
})
- wq := waiter.Queue{}
+ var wq waiter.Queue
rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index f1028823b..cdf0459e3 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- wq := waiter.Queue{}
+ var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
@@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
)
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
-
tests := []struct {
name string
broadcastAddr tcpip.Address
@@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
})
+ type endpointAndWaiter struct {
+ ep tcpip.Endpoint
+ ch chan struct{}
+ }
+ var eps []endpointAndWaiter
// We create endpoints that bind to both the wildcard address and the
// broadcast address to make sure both of these types of "broadcast
// interested" endpoints receive broadcast packets.
- wq := waiter.Queue{}
- var eps []tcpip.Endpoint
for _, bindWildcard := range []bool{false, true} {
// Create multiple endpoints for each type of "broadcast interested"
// endpoint so we can test that all endpoints receive the broadcast
// packet.
for i := 0; i < 2; i++ {
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
@@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
}
- eps = append(eps, ep)
+ eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
}
}
@@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- if n, _, err := wep.Write(data, writeOpts); err != nil {
+ data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
+ if n, _, err := wep.ep.Write(data, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
}
for j, rep := range eps {
- if gotPayload, _, err := rep.Read(nil); err != nil {
+ // Wait for the endpoint to become readable.
+ <-rep.ch
+
+ if gotPayload, _, err := rep.ep.Read(nil); err != nil {
t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
} else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
new file mode 100644
index 000000000..02fc47015
--- /dev/null
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -0,0 +1,388 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TestLocalPing tests pinging a remote that is local the stack.
+//
+// This tests that a local route is created and packets do not leave the stack.
+func TestLocalPing(t *testing.T) {
+ const (
+ nicID = 1
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
+ channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
+ channelEP := e.(*channel.Endpoint)
+ if n := channelEP.Drain(); n != 0 {
+ t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
+ }
+ }
+
+ ipv4ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ ipv6ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ linkEndpoint func() stack.LinkEndpoint
+ localAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ expectedConnectErr *tcpip.Error
+ checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
+ }{
+ {
+ name: "IPv4 loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: ipv4Loopback,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: header.IPv6Loopback,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv4Addr.Address,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv6Addr.Address,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv4 loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
+
+ if test.expectedConnectErr != nil {
+ return
+ }
+
+ payload := tcpip.SlicePayload(test.icmpBuf(t))
+ var wOpts tcpip.WriteOptions
+ if n, _, err := ep.Write(payload, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
+}
+
+// TestLocalUDP tests sending UDP packets between two endpoints that are local
+// to the stack.
+//
+// This tests that that packets never leave the stack and the addresses
+// used when sending a packet.
+func TestLocalUDP(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ tests := []struct {
+ name string
+ canBePrimaryAddr tcpip.ProtocolAddress
+ firstPrimaryAddr tcpip.ProtocolAddress
+ }{
+ {
+ name: "IPv4",
+ canBePrimaryAddr: ipv4Addr1,
+ firstPrimaryAddr: ipv4Addr2,
+ },
+ {
+ name: "IPv6",
+ canBePrimaryAddr: ipv6Addr1,
+ firstPrimaryAddr: ipv6Addr2,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ addAddress bool
+ expectedWriteErr *tcpip.Error
+ }{
+ {
+ name: "Unassigned local address",
+ addAddress: false,
+ expectedWriteErr: tcpip.ErrNoRoute,
+ },
+ {
+ name: "Assigned local address",
+ addAddress: true,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ HandleLocal: true,
+ }
+
+ s := stack.New(stackOpts)
+ ep := channel.New(1, header.IPv6MinimumMTU, "")
+
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if subTest.addAddress {
+ if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ }
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer server.Close()
+
+ bindAddr := tcpip.FullAddress{Port: 80}
+ if err := server.Bind(bindAddr); err != nil {
+ t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventIn)
+ client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer client.Close()
+
+ serverAddr := tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ Port: 80,
+ }
+
+ clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &serverAddr,
+ }
+ if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
+ } else if subTest.expectedWriteErr != nil {
+ // Nothing else to test if we expected not to be able to send the
+ // UDP packet.
+ return
+ } else if n != int64(len(clientPayload)) {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload))
+ }
+ }
+
+ // Wait for the server endpoint to become readable.
+ <-serverCH
+
+ var clientAddr tcpip.FullAddress
+ if v, _, err := server.Read(&clientAddr); err != nil {
+ t.Fatalf("server.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ }
+ if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
+ t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &clientAddr,
+ }
+ if n, _, err := server.Write(serverPayload, wOpts); err != nil {
+ t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
+ } else if n != int64(len(serverPayload)) {
+ t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload))
+ }
+ }
+
+ // Wait for the client endpoint to become readable.
+ <-clientCH
+
+ var gotServerAddr tcpip.FullAddress
+ if v, _, err := client.Read(&gotServerAddr); err != nil {
+ t.Fatalf("client.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ }
+ if gotServerAddr.Addr != serverAddr.Addr {
+ t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 41eb0ca44..763cd8f84 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
default:
@@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
},
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 87d510f96..3820e5dc7 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 072601d2d..31831a6d8 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.AcceptConnOption:
+ return false, nil
+ default:
+ return false, tcpip.ErrNotSupported
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e37c00523..7b6a87ba9 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.IPHdrIncludedOption:
@@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ remoteAddr := pkt.Network().SourceAddress()
+
if e.bound {
// If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
return
}
// If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// If connected, only accept packets from the remote address we
// connected to.
- if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// Push new packet into receive list and increment the buffer size.
packet := &rawPacket{
senderAddr: tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress,
+ NIC: pkt.NICID,
+ Addr: remoteAddr,
},
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 33bfb56cd..7d97cbdc7 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
}
// beforeSave is invoked by stateify.
-func (ep *endpoint) beforeSave() {
+func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
// The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
// packets.
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
}
// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
// Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
+ e.rcvBufSizeMax = 0
// Release the lock acquired in beforeSave() so regular endpoint closing
// logic can proceed after save.
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return max
}
// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
}
// afterLoad is invoked by stateify.
-func (ep *endpoint) afterLoad() {
- stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
-func (ep *endpoint) Resume(s *stack.Stack) {
- ep.stack = s
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
// If the endpoint is connected, re-connect.
- if ep.connected {
+ if e.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
if err != nil {
panic(err)
}
}
// If the endpoint is bound, re-bind.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
+ if e.bound {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if ep.associated {
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ if e.associated {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b706438bd..47982ca41 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
- netProto = s.route.NetProto
+ netProto = s.netProto
}
+
+ route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(s.remoteLinkAddr)
+
n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6Only
n.ID = s.id
- n.boundNICID = s.route.NICID()
- n.route = s.route.Clone()
- n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.boundNICID = s.nicID
+ n.route = route
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- return n
+ return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ if err != nil {
+ return nil, err
+ }
// Lock the endpoint before registering to ensure that no out of
// band changes are possible due to incoming packets etc till
@@ -425,20 +435,17 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer ctx.synRcvdCount.dec()
- defer func() {
- e.mu.Lock()
- e.decSynRcvdCount()
- e.mu.Unlock()
- }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
+ e.decSynRcvdCount()
return
}
ctx.removePendingEndpoint(n)
+ e.decSynRcvdCount()
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
@@ -456,7 +463,9 @@ func (e *endpoint) incSynRcvdCount() bool {
}
func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
e.synRcvdCount--
+ e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
@@ -468,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -478,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
switch {
@@ -493,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
- return
+ return nil
}
ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
} else {
// If cookies are in use but the endpoint accept queue
// is full then drop the syn.
@@ -507,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Send SYN without window scaling because we currently
// don't encode this information in the cookie.
//
@@ -524,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TS: opts.TS,
TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
}
- e.sendSynTCP(&s.route, tcpFields{
+ fields := tcpFields{
id: s.id,
ttl: e.ttl,
tos: e.sendTOS,
@@ -534,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
- }, synOpts)
+ }
+ if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ return err
+ }
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
}
case (s.flags & header.TCPFlagAck) != 0:
@@ -548,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
if !ctx.synRcvdCount.synCookiesInUse() {
@@ -567,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
iss := s.ackNumber - 1
@@ -588,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -609,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ if err != nil {
+ return err
+ }
n.mu.Lock()
@@ -623,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return nil
}
// Register new endpoint so that packets are routed to it.
@@ -633,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return err
}
n.isRegistered = true
@@ -671,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
+ return nil
+
+ default:
+ return nil
}
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
-func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.v6only
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
@@ -715,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
- return nil
+ return
}
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
s := e.segmentQueue.dequeue()
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
close(e.drainDone)
@@ -739,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
break
}
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 0aaef495d..2facbebec 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
MSS: amss,
}
if ttl == 0 {
- ttl = s.route.DefaultTTL()
+ ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
// Wait for notification.
- index, _ = s.Fetch(true)
+ h.ep.mu.Unlock()
+ index, _ = s.Fetch(true /* block */)
+ h.ep.mu.Lock()
}
}
@@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error {
}, synOpts)
for h.state != handshakeCompleted {
+ // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
+ // throughout handshake processing).
h.ep.mu.Unlock()
- index, _ := s.Fetch(true)
+ index, _ := s.Fetch(true /* block */)
h.ep.mu.Lock()
switch index {
@@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// TCP header, then the kernel calculate a checksum of the
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
- } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ } else if r.RequiresTXTransportChecksum() {
xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
@@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
if ep == nil {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
s.decRef()
return
}
@@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
drained := e.drainDone != nil
if drained {
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
// Set up the functions that will be called when the main protocol loop
@@ -1535,7 +1541,7 @@ loop:
}
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
// We need to double check here because the notification may be
@@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
}
for _, netProto := range netProtos {
- if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil {
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
@@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
for {
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
switch v {
case newSegment:
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 98aecab9e..21162f01a 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -172,10 +172,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newSegment(r, id, pkt)
- if !s.parse() {
+
+ s := newIncomingSegment(id, pkt)
+ if !s.parse(pkt.RXTransportChecksumValidated) {
ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 560b4904c..a6f25896b 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
testV6Connect(t, c)
}
+func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) {
+ c := context.NewWithOpts(t, context.Options{
+ EnableV6: true,
+ MTU: defaultMTU,
+ })
+ defer c.Cleanup()
+
+ // Create a v6 endpoint but don't set the v6-only TCP option.
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3bcd3923a..258f9f1bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -721,9 +721,9 @@ func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
// by another user goroutine if not then we spin, otherwise
- // we just goto sleep on the Lock() and wait.
+ // we just go to sleep on the Lock() and wait.
if !e.mu.TryLock() {
- // If socket is owned by the user then just goto sleep
+ // If socket is owned by the user then just go to sleep
// as the lock could be held for a reasonably long time.
if atomic.LoadUint32(&e.ownedByUser) == 1 {
e.mu.Lock()
@@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
+ s := newOutgoingSegment(e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
@@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.MulticastLoopOption:
return true, nil
+ case tcpip.AcceptConnOption:
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.EndpointState() == StateListen, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -2310,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// done yet) or the reservation was freed between the check above and
// the FindTransportEndpoint below. But rather than retry the same port
// we just skip it and move on.
- transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID())
if transEP == nil {
// ReservePort failed but there is no registered endpoint with
// demuxer. Which indicates there is at least some endpoint that has
@@ -2379,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.ID
- s.route = r.Clone()
e.sndWaker.Assert()
}
}
@@ -2445,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.ID, nil)
+ s := newOutgoingSegment(e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
// Mark endpoint as closed.
@@ -2627,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+
+ // Expand netProtos to include v4 and v6 under dual-stack if the caller is
+ // binding to a wildcard (empty) address, and this is an IPv6 endpoint with
+ // v6only set to false.
+ if netProto == header.IPv6ProtocolNumber {
+ stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
+ alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ if alsoBindToV4 {
+ netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
}
@@ -2715,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
}
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -3074,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() {
}
func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.route.HasHardwareGSOCapability() {
e.initHardwareGSO()
- } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ } else if e.route.HasSoftwareGSOCapability() {
e.gso = &stack.GSO{
MaxSize: e.route.GSOMaxSize(),
Type: stack.GSOSW,
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b25431467..2bcc5e1c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() {
switch {
case epState == StateInitial || epState == StateBound:
case epState.connected() || epState.handshake():
- if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
- if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ if !e.route.HasSaveRestoreCapability() {
+ if !e.route.HasDisconncetOkCapability() {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 070b634b4..0664789da 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -30,6 +30,8 @@ import (
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
type Forwarder struct {
+ stack *stack.Stack
+
maxInFlight int
handler func(*ForwarderRequest)
@@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
+ stack: s,
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
@@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newSegment(r, id, pkt)
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
- if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn {
return false
}
@@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
delete(r.forwarder.inFlight, r.segment.id)
r.forwarder.mu.Unlock()
- // If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
+ replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */)
}
// Release all resources.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 5bce73605..2329aca4b 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(r, ep, id, pkt)
+func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(ep, id, pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newSegment(r, id, pkt)
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
- if !s.parse() || !s.csumValid {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
if !s.flagIsSet(header.TCPFlagRst) {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(p.stack, s, stack.DefaultTOS, 0)
}
return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment, tos, ttl uint8) {
+//
+// If the passed TTL is 0, then the route's default TTL will be used.
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+ route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, tcpFields{
+
+ if ttl == 0 {
+ ttl = route.DefaultTTL()
+ }
+
+ return sendTCP(&route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 7ef2df377..833a7b470 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
return found
}
-// Dump prints the state of the scoreboard structure.
+// String returns human-readable state of the scoreboard structure.
func (s *SACKScoreboard) String() string {
var str strings.Builder
str.WriteString("SACKScoreboard: {")
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1f9c5cf50..2091989cc 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -45,9 +46,18 @@ type segment struct {
ep *endpoint
qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
- route stack.Route `state:"manual"`
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- hdr header.TCP
+
+ // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of
+ // individual members for link/network packet info.
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ netProto tcpip.NetworkProtocolNumber
+ nicID tcpip.NICID
+ remoteLinkAddr tcpip.LinkAddress
+
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -76,11 +86,16 @@ type segment struct {
acked bool
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+ netHdr := pkt.Network()
s := &segment{
- refCnt: 1,
- id: id,
- route: r.Clone(),
+ refCnt: 1,
+ id: id,
+ srcAddr: netHdr.SourceAddress(),
+ dstAddr: netHdr.DestinationAddress(),
+ netProto: pkt.NetworkProtocolNumber,
+ nicID: pkt.NICID,
+ remoteLinkAddr: pkt.SourceLinkAddress(),
}
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
@@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB
return s
}
-func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
- route: r.Clone(),
}
s.rcvdTime = time.Now()
if len(v) != 0 {
@@ -110,7 +124,9 @@ func (s *segment) clone() *segment {
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
- route: s.route.Clone(),
+ netProto: s.netProto,
+ nicID: s.nicID,
+ remoteLinkAddr: s.remoteLinkAddr,
viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
@@ -160,7 +176,6 @@ func (s *segment) decRef() {
panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
}
}
- s.route.Release()
}
}
@@ -198,10 +213,10 @@ func (s *segment) segMemSize() int {
//
// Returns boolean indicating if the parsing was successful.
//
-// If checksum verification is not offloaded then parse also verifies the
+// If checksum verification may not be skipped, parse also verifies the
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
-func (s *segment) parse() bool {
+func (s *segment) parse(skipChecksumValidation bool) bool {
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -220,16 +235,14 @@ func (s *segment) parse() bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
- // Query the link capabilities to decide if checksum validation is
- // required.
verifyChecksum := true
- if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ if skipChecksumValidation {
s.csumValid = true
verifyChecksum = false
}
if verifyChecksum {
s.csum = s.hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 6fa8d63cd..ab5fa4fb7 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
+ if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 {
+ return
+ }
+
// Sort the SACK blocks. The first block is the most recent unacked
// block. The following blocks can be in arbitrary order.
sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index a7149efd0..5f05608e2 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) {
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
@@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
}
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
@@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) {
// Test acceptance.
// Start listening.
- listenBacklog := 2
+ listenBacklog := 10
if err := c.EP.Listen(listenBacklog); err != nil {
t.Fatalf("Listen failed: %s", err)
}
- for i := 0; i < listenBacklog; i++ {
- executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */)
+ lastPortOffset := uint16(0)
+ for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
time.Sleep(50 * time.Millisecond)
@@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now execute send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 2,
+ SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(789),
@@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
@@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+ // drain any older notifications from the notification channel before attempting
+ // 2nd connection.
+ select {
+ case <-ch:
+ default:
+ }
+
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
iss = seqnum.Value(792)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 4d7847142..f791f8f13 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -112,6 +112,18 @@ type Headers struct {
TCPOpts []byte
}
+// Options contains options for creating a new test context.
+type Options struct {
+ // EnableV4 indicates whether IPv4 should be enabled.
+ EnableV4 bool
+
+ // EnableV6 indicates whether IPv4 should be enabled.
+ EnableV6 bool
+
+ // MTU indicates the maximum transmission unit on the link layer.
+ MTU uint32
+}
+
// Context provides an initialized Network stack and a link layer endpoint
// for use in TCP tests.
type Context struct {
@@ -154,10 +166,30 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ return NewWithOpts(t, Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: mtu,
})
+}
+
+// NewWithOpts allocates and initializes a test context containing a new
+// stack and a link-layer endpoint with specific options.
+func NewWithOpts(t *testing.T, opts Options) *Context {
+ if opts.MTU == 0 {
+ panic("MTU must be greater than 0")
+ }
+
+ stackOpts := stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ }
+ if opts.EnableV4 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
+ }
+ if opts.EnableV6 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
+ }
+ s := stack.New(stackOpts)
const sendBufferSize = 1 << 20 // 1 MiB
const recvBufferSize = 1 << 20 // 1 MiB
@@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- ep := channel.New(1000, mtu, "")
+ ep := channel.New(1000, opts.MTU, "")
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- opts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
}
opts2 := stack.NICOptions{Name: "nic2"}
if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- v4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: StackAddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
- }
-
- v6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: StackV6AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
- }
+ var routeTable []tcpip.Route
- s.SetRouteTable([]tcpip.Route{
- {
+ if opts.EnableV4 {
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
- },
- {
+ })
+ }
+
+ if opts.EnableV6 {
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
NIC: 1,
- },
- })
+ })
+ }
+
+ s.SetRouteTable(routeTable)
return &Context{
t: t,
@@ -373,6 +410,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code,
const icmpv4VariableHeaderOffset = 4
copy(icmp[icmpv4VariableHeaderOffset:], p1)
copy(icmp[header.ICMPv4PayloadOffset:], p2)
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
// Inject packet.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d57ed5d79..9bcb918bb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
+ if to.Port == 0 {
+ // Port 0 is an invalid port to send to.
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -895,6 +900,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.AcceptConnOption:
+ return false, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1009,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ if r.RequiresTXTransportChecksum() &&
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
@@ -1366,6 +1374,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.rcvMu.Unlock()
}
+ e.lastErrorMu.Lock()
+ hasError := e.lastError != nil
+ e.lastErrorMu.Unlock()
+ if hasError {
+ result |= waiter.EventErr
+ }
return result
}
@@ -1373,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
// omitted the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
-func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool {
- if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
- (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
+ if !pkt.RXTransportChecksumValidated &&
+ (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
+ netHdr := pkt.Network()
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
}
@@ -1387,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
@@ -1397,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- if !verifyChecksum(r, hdr, pkt) {
+ if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
e.stats.ReceiveErrors.ChecksumErrors.Increment()
@@ -1428,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
@@ -1438,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvBufSize += pkt.Data.Size()
// Save any useful information from the network header to the packet.
- switch r.NetProto {
+ switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
@@ -1448,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
// address. packetInfo.LocalAddr should hold a unicast address that can be
// used to respond to the incoming packet.
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.LocalAddress
- packet.packetInfo.NIC = r.NICID()
+ localAddr := pkt.Network().DestinationAddress()
+ packet.packetInfo.LocalAddr = localAddr
+ packet.packetInfo.DestinationAddr = localAddr
+ packet.packetInfo.NIC = pkt.NICID
packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1465,14 +1481,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
- defer e.mu.RUnlock()
-
if e.state == StateConnected {
e.lastErrorMu.Lock()
- defer e.lastErrorMu.Unlock()
-
e.lastError = tcpip.ErrConnectionRefused
+ e.lastErrorMu.Unlock()
+ e.mu.RUnlock()
+
+ e.waiterQueue.Notify(waiter.EventErr)
+ return
}
+ e.mu.RUnlock()
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 3ae6cc221..14e4648cd 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
- route: r,
id: id,
pkt: pkt,
})
@@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p
// it via CreateEndpoint.
type ForwarderRequest struct {
stack *stack.Stack
- route *stack.Route
id stack.TransportEndpointID
pkt *stack.PacketBuffer
}
@@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ netHdr := r.pkt.Network()
+ route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(r.pkt.SourceLinkAddress())
+
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
+ route.Release()
return nil, err
}
ep.ID = r.id
- ep.route = r.route.Clone()
+ ep.route = route
ep.dstPort = r.id.RemotePort
- ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto}
- ep.RegisterNICID = r.route.NICID()
+ ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
+ ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
@@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.pkt)
+ ep.HandlePacket(r.id, r.pkt)
return ep, nil
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index da5b1deb2..91420edd3 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets that are targeted at this
// protocol but don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ p.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return stack.UnknownDestinationPacketMalformed
}
- if !verifyChecksum(r, hdr, pkt) {
- r.Stack().Stats().UDP.ChecksumErrors.Increment()
+ if !verifyChecksum(hdr, pkt) {
+ p.stack.Stats().UDP.ChecksumErrors.Increment()
return stack.UnknownDestinationPacketMalformed
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b4604ba35..fb7738dda 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) {
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
-
}
// In the case of large payloads the IP packet may be truncated. Update
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index 5c4b9e8e9..a38ffc19d 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -53,40 +53,40 @@ func randomFilename() (string, error) {
func TestConnectFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
if _, err := Connect(name, false); err == nil {
- t.Fatalf("connect was successful, expected err")
+ t.Fatalf("Connect was successful, expected err")
}
}
func TestBindFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
if _, err = BindAndListen(name, false); err == nil {
- t.Fatalf("second bind succeeded, expected non-nil err")
+ t.Fatalf("Second bind succeeded, expected non-nil err")
}
}
func TestMultipleAccept(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
@@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) {
defer wg.Done()
s, err := Connect(name, false)
if err != nil {
- t.Fatalf("connect failed, got err %v expected nil", err)
+ t.Errorf("Connect failed, got err %v expected nil", err)
+ return
}
s.Close()
}()
@@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) {
for i := 0; i < backlog; i++ {
s, err := ss.Accept()
if err != nil {
- t.Errorf("accept failed, got err %v expected nil", err)
+ t.Errorf("Accept failed, got err %v expected nil", err)
continue
}
s.Close()
@@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) {
func TestServerClose(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
// Make sure the first close succeeds.
if err := ss.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := ss.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
// Bind a server.
ss, err := BindAndListen(name, packet)
if err != nil {
- t.Fatalf("error binding, got %v expected nil", err)
+ t.Fatalf("Error binding, got %v expected nil", err)
}
defer ss.Close()
@@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
// Connect the client.
client, err := Connect(name, packet)
if err != nil {
- t.Fatalf("error connecting, got %v expected nil", err)
+ t.Fatalf("Error connecting, got %v expected nil", err)
}
// Grab the server handle.
@@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
case server := <-acceptSocket:
return server, client
case err := <-acceptErr:
- t.Fatalf("accept error: %v", err)
+ t.Fatalf("Accept error: %v", err)
}
panic("unreachable")
}
@@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) {
// Write on the server.
w := server.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -233,13 +234,13 @@ func TestPacket(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Write on the client again.
w = client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
@@ -249,19 +250,19 @@ func TestPacket(t *testing.T) {
b := [][]byte{{'b', 'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Do it again.
r = server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -271,12 +272,12 @@ func TestClose(t *testing.T) {
// Make sure the first close succeeds.
if err := client.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := client.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
@@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) {
// We're good. That's what we wanted.
blockCount++
} else {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
}
}
}
if blockCount == 1000 {
// Shouldn't have _always_ blocked.
- t.Fatalf("socket always blocked!")
+ t.Fatalf("Socket always blocked!")
} else if blockCount == 0 {
// Should have started blocking eventually.
- t.Fatalf("socket never blocked!")
+ t.Fatalf("Socket never blocked!")
}
}
@@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) {
// Expected to block immediately.
_, err := r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
// Put some data in the pipe.
w := server.Writer(false)
if n, err := w.WriteVec(b); n != 1 || err != nil {
- t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it not to block.
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it to return a block error again.
r = client.Reader(false)
_, err = r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
}
@@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c'}, {'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[1][0] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
}
}
@@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c', 'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[0][1] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
}
}
@@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) {
w := server.Writer(true)
w.PackFDs(0, 1, 2)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client, without enabling FDs.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Make sure the FDs are not received.
fds, err := r.ExtractFDs()
if len(fds) != 0 || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
}
}
@@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) {
w := s.Writer(true)
w.PackFDs(fds...)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
}
@@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Count the number of FDs.
preEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
// Read on the client.
@@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
r.EnableFDs(enableSize)
}
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Count the new number of FDs.
postEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(preEntries)+expected != len(postEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
}
// Make sure the FDs are there.
fds, err := r.ExtractFDs()
if len(fds) != expected || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
}
// Make sure they are different from the originals.
for i := 0; i < len(fds); i++ {
if fds[i] == origFDs[i] {
- t.Errorf("got original fd for index %d, expected different", i)
+ t.Errorf("Got original fd for index %d, expected different", i)
}
}
@@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Make sure the count is back to normal.
finalEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(finalEntries) != len(preEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
}
}
@@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) {
}
if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
+ t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
}
}
@@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) {
want := "bad file descriptor"
if _, err := s.GetPeerCred(); err == nil || err.Error() != want {
- t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want)
+ t.Errorf("s.GetPeerCred() = %v, want = %s", err, want)
}
}
func TestAcceptClosed(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Fatalf("Close failed, got err %v expected nil", err)
}
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
}
func TestCloseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Errorf("Close failed, got err %v expected nil", err)
}
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
@@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) {
func TestReleaseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
fd, err := ss.Release()
if err != nil {
- t.Fatalf("Release failed, got err %v expected nil", err)
+ t.Errorf("Release failed, got err %v expected nil", err)
}
syscall.Close(fd)
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
@@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) {
cm.PackFDs(want...)
got, err := cm.ExtractFDs()
if err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
+ t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
}
}
}
@@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) {
for i := 0; i < b.N; i++ {
n, err := server.Read(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
n, err = server.Write(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
}
}()
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 67a950444..08519d986 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
//
// +stateify savable
type Queue struct {
- list waiterList `state:"zerovalue"`
+ list waiterList
mu sync.RWMutex `state:"nosave"`
}